File size: 9,387 Bytes
"""
audio_model.py
==============
AASISTDeepFake model definition — matches the training notebook exactly.
Import this in both training scripts and the Gradio app (via audio_detector_inference.py).

Architecture:
    Raw waveform → SincConv → Downsample (32×) → Res2Block
    → CNN (2 layers) → GraphAttn (×2) → AttentionPool → Classifier

Label convention (from training dataset enumerate(["fake", "real"])):
    label = 0 → Fake
    label = 1 → Real
    sigmoid(logit) >= threshold → Real
    sigmoid(logit) <  threshold → Fake
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Audio constants (must match training) ───────────────────────────────────
SAMPLE_RATE: int = 16_000                           # Hz
MAX_DURATION: float = 5.0                           # seconds per clip
MAX_SAMPLES: int = int(MAX_DURATION * SAMPLE_RATE)  # 80 000 samples per clip
# ── Sub-modules ─────────────────────────────────────────────────────────────
class SincConv(nn.Module):
    """
    Learnable sinc-function band-pass filter bank (SincNet-style front end).

    Only 2×out_channels learnable parameters: one low cut-off (``low_hz_``)
    and one bandwidth (``band_hz_``) per filter, both in Hz. They are
    initialised from ``out_channels`` bands equally spaced on the mel scale
    between 30 Hz and ``sample_rate / 2 - 100`` Hz.

    Input : (B, T) waveform
    Output: (B, out_channels, T) band-passed signals. The kernel is forced
            to an odd length and padded by ``kernel_size // 2``, so the time
            length is preserved.

    NOTE: the parameter/buffer names (``low_hz_``, ``band_hz_``, ``n_``,
    ``window_``) are state-dict keys; renaming them breaks checkpoints.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to mels (2595 * log10(1 + hz / 700))."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert mels back to Hz; inverse of :meth:`to_mel`."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels: int = 64, kernel_size: int = 512,
                 sample_rate: int = 16_000):
        super().__init__()
        self.out_channels = out_channels
        # Force an odd kernel so the filter has a single centre tap.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        self.sample_rate = sample_rate

        # Mel-spaced initial band edges -> (out_channels, 1) learnable tensors.
        low_hz, high_hz = 30, sample_rate / 2 - 100
        mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), out_channels + 1)
        hz = self.to_hz(mel)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Angular time axis for the positive half of the kernel:
        # n_ = 2*pi*n/sample_rate for n = 1..half, shape (1, half).
        half = (self.kernel_size - 1) // 2
        n = torch.arange(1, half + 1, dtype=torch.float32)
        self.register_buffer('n_', (2 * np.pi * n / sample_rate).unsqueeze(0))
        # torch.hamming_window defaults to the periodic variant.
        self.register_buffer('window_', torch.hamming_window(self.kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter a (B, T) waveform into (B, out_channels, T)."""
        # Keep cut-offs >= 50 Hz and bandwidth >= 50 Hz; cap high edge at Nyquist.
        low = 50 + torch.abs(self.low_hz_)                        # (C, 1)
        high = torch.clamp(low + 50 + torch.abs(self.band_hz_),
                           max=self.sample_rate / 2)              # (C, 1)
        band = (high - low)[:, 0]                                 # (C,)

        f1 = torch.matmul(low, self.n_)                           # (C, half)
        f2 = torch.matmul(high, self.n_)                          # (C, half)
        # sin(2*pi*f*n/sr) / (pi*n/sr): ideal low-pass response at taps n >= 1.
        lp1 = torch.sin(f1) / (np.pi * self.n_ / (2 * np.pi))
        lp2 = torch.sin(f2) / (np.pi * self.n_ / (2 * np.pi))
        bp = (lp2 - lp1) / (2 * band[:, None])                    # band-pass, n >= 1

        # NOTE(review): with this normalisation a textbook sinc band-pass has a
        # centre tap of 1 (the n -> 0 limit of bp); this code uses 0. Presumably
        # kept to match the training notebook / saved checkpoints, so it is left
        # unchanged. Do not alter without retraining. TODO confirm intent.
        # new_zeros follows bp's dtype/device, so half/double models also work.
        centre = bp.new_zeros(self.out_channels, 1)
        filters = torch.cat([bp.flip(1), centre, bp], dim=1)      # (C, kernel_size)
        filters = filters * self.window_
        x = x.unsqueeze(1)                                        # (B, 1, T)
        return F.conv1d(x, filters.unsqueeze(1), padding=self.kernel_size // 2)
class Res2Block(nn.Module):
    """
    Multi-scale (Res2Net-style) block with hierarchical inter-group accumulation.

    The channel axis is split into ``scale`` equal groups:

    * group 0 is passed through unchanged;
    * group 1 goes through its own conv -> BN -> GELU;
    * every later group i is first added to the output of group i-1 and then
      goes through its own conv -> BN -> GELU.

    The group outputs are concatenated back. A 3-tap conv with
    ``padding = dilation`` keeps the time length, so the output shape equals
    the input shape.

    Args:
        channels: number of input/output channels; must be divisible by ``scale``.
        scale: number of channel groups (>= 1). ``scale == 1`` has no
            convolutions and acts as an identity mapping.
        dilation: dilation of every 3-tap convolution.

    Raises:
        ValueError: if ``scale < 1`` or ``channels`` is not divisible by ``scale``.
    """

    def __init__(self, channels: int, scale: int = 8, dilation: int = 1):
        super().__init__()
        # Explicit raises (not ``assert``) so validation survives ``python -O``.
        if scale < 1:
            raise ValueError(f"scale ({scale}) must be >= 1")
        if channels % scale != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by scale ({scale})")
        self.scale = scale
        width = channels // scale
        # One conv/BN pair per group except group 0 (the pass-through group).
        self.convs = nn.ModuleList([
            nn.Conv1d(width, width, 3, padding=dilation, dilation=dilation)
            for _ in range(scale - 1)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm1d(width) for _ in range(scale - 1)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, channels, T) tensor; returns the same shape."""
        chunks = torch.chunk(x, self.scale, dim=1)
        out = [chunks[0]]
        y = None
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns), start=1):
            # Group 1 starts fresh; later groups accumulate the previous output.
            inp = chunks[i] if y is None else y + chunks[i]
            y = F.gelu(bn(conv(inp)))
            out.append(y)
        return torch.cat(out, dim=1)
class GraphAttn(nn.Module):
    """
    Memory-efficient multi-head self-attention over temporal frames.

    Sequences longer than ``max_tokens`` frames are average-pooled down to
    ``max_tokens`` before attention. The attention output is then linearly
    interpolated back to the original length, so the caller can add it to the
    un-pooled input as a residual. The residual addition itself is NOT done
    here.

    Args:
        dim: feature size C of the input; must be divisible by ``heads``.
        heads: number of attention heads.
        max_tokens: longest sequence attended at full resolution. The default
            of 64 is the value the original code hard-coded.

    Input : (B, N, C)
    Output: (B, N, C)

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads`` or ``max_tokens < 1``.
    """

    def __init__(self, dim: int, heads: int = 4, max_tokens: int = 64):
        super().__init__()
        # Fail at construction rather than with an opaque reshape error later.
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        if max_tokens < 1:
            raise ValueError(f"max_tokens ({max_tokens}) must be >= 1")
        self.heads = heads
        self.head_dim = dim // heads
        self.max_tokens = max_tokens
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over a (B, N, C) sequence and return a (B, N, C) update."""
        B, N, C = x.shape
        if N > self.max_tokens:
            # (B, N, C) -> (B, C, N) -> pool time to max_tokens -> (B, max_tokens, C)
            x_pool = F.adaptive_avg_pool1d(
                x.transpose(1, 2), self.max_tokens).transpose(1, 2)
        else:
            x_pool = x
        Bp, Np, Cp = x_pool.shape

        # (B, Np, 3C) -> (3, B, heads, Np, head_dim)
        qkv = (self.qkv(x_pool)
               .reshape(Bp, Np, 3, self.heads, self.head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)
        attn = torch.softmax(
            q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(Bp, Np, Cp)
        out = self.out(out)

        # Only resample when pooling changed the length; a same-size linear
        # interpolation would just copy the tensor.
        if Np != N:
            out = F.interpolate(
                out.transpose(1, 2), size=N,
                mode='linear', align_corners=False).transpose(1, 2)
        return out
class AttentionPool(nn.Module):
    """
    Soft-attention pooling over the sequence axis.

    A linear layer gives every frame one scalar score. The scores are
    softmax-normalised over time, and the frames are summed with those weights.

    Input : (B, T, dim)
    Output: (B, dim)
    """

    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, 1)  # one score per frame

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.attn(x)                   # (B, T, 1)
        weights = torch.softmax(scores, dim=1)  # normalised over time
        weighted = weights * x                  # broadcast over features
        return torch.sum(weighted, dim=1)       # (B, dim)
# ── Main model ──────────────────────────────────────────────────────────────
class AASISTDeepFake(nn.Module):
    """
    AASISTDeepFake — memory-efficient raw-waveform audio spoof detector.

    Input : (B, 80 000) float32 waveform, normalised to [-1, 1]
    Output: (B, 1) raw logit; sigmoid(logit) is P(real)

    Prediction:
        sigmoid(logit) >= threshold → Real (label 1)
        sigmoid(logit) <  threshold → Fake (label 0)

    NOTE: attribute names and the layer order inside each nn.Sequential
    determine the state-dict keys; changing either breaks saved checkpoints.
    """

    def __init__(
        self,
        sinc_ch: int = 64,
        sinc_kernel: int = 512,
        hidden: int = 128,
        graph_heads: int = 4,
        n_graph: int = 2,
    ):
        super().__init__()
        # Front end: learnable band-pass filter bank on the raw waveform.
        self.sinc = SincConv(sinc_ch, sinc_kernel, SAMPLE_RATE)
        self.bn_sinc = nn.BatchNorm1d(sinc_ch)

        # Two strided convs shrink time by 8 then 4 (32× total) to bound memory.
        down_layers = []
        for stride in (8, 4):
            down_layers += [
                nn.Conv1d(sinc_ch, sinc_ch, kernel_size=stride, stride=stride),
                nn.BatchNorm1d(sinc_ch),
                nn.GELU(),
            ]
        self.downsample = nn.Sequential(*down_layers)

        self.encoder = nn.Sequential(
            Res2Block(sinc_ch),
            nn.BatchNorm1d(sinc_ch),
            nn.GELU(),
        )

        # sinc_ch -> hidden, then hidden -> hidden; padding keeps the length.
        cnn_layers = []
        for in_ch in (sinc_ch, hidden):
            cnn_layers += [
                nn.Conv1d(in_ch, hidden, kernel_size=3, padding=1),
                nn.BatchNorm1d(hidden),
                nn.GELU(),
            ]
        self.cnn = nn.Sequential(*cnn_layers)

        self.graph_layers = nn.ModuleList(
            GraphAttn(hidden, graph_heads) for _ in range(n_graph))
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden) for _ in range(n_graph))
        self.pool = AttentionPool(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 64),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.sinc(x).abs()                     # (B, sinc_ch, T)
        feats = F.gelu(self.bn_sinc(feats))
        feats = self.downsample(feats)                 # (B, sinc_ch, T/32)
        feats = self.cnn(self.encoder(feats))          # (B, hidden, T/32)

        seq = feats.transpose(1, 2)                    # (B, T/32, hidden)
        for graph, norm in zip(self.graph_layers, self.layer_norms):
            # Post-norm residual: GraphAttn returns only the update.
            seq = norm(seq + graph(seq))

        return self.classifier(self.pool(seq))         # (B, 1)
# ── Helper ──────────────────────────────────────────────────────────────────
def load_audio_model(
    checkpoint: str,
    device: "torch.device | str | None" = None,
    **model_kwargs,
) -> AASISTDeepFake:
    """
    Load a trained AASISTDeepFake from a ``.pt`` state-dict checkpoint.

    Args:
        checkpoint: path to the checkpoint file. It is presumably produced by
            ``torch.save(model.state_dict(), ...)``; TODO confirm against the
            training notebook.
        device: target device (``torch.device`` or a string such as
            ``"cpu"``). When omitted, CUDA is used if available, else CPU.
        **model_kwargs: forwarded to ``AASISTDeepFake(...)``. This lets you load
            checkpoints trained with non-default hyper-parameters (e.g.
            ``hidden=256``). Omit them for the default architecture.

    Returns:
        The model with weights loaded, in eval mode, on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AASISTDeepFake(**model_kwargs)
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints. On PyTorch versions that
    # support it, consider passing weights_only=True.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
|