JunSiang26's picture
Pure production deploy
9f2b6db
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model
# ============================================================
# 1. Wav2Vec2 Detector (Self-supervised Transformer Baseline)
# ============================================================
class AttentivePooling(nn.Module):
    """Collapse a sequence into one vector using learned soft attention.

    A small scorer (Linear -> Tanh -> Linear) produces one logit per position.
    The logits are softmax-normalised over dim 1, and the inputs are summed
    with those weights.

    Args:
        dim: size of the last (feature) axis of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, 1))

    def forward(self, x):
        # assumes x is (batch, time, dim); the weighting/reduction is over dim 1
        weights = self.attn(x).softmax(dim=1)
        return (weights * x).sum(dim=1)
class Wav2Vec2SpoofDetector(nn.Module):
    """Frozen wav2vec 2.0 encoder + attentive pooling + linear head.

    Args:
        num_classes: number of output logits.
        model_name: checkpoint name passed to ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, num_classes=2, model_name="facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
        # Encoder weights are frozen; only the pooling and head are trained.
        # NOTE(review): the encoder still follows this module's train()/eval()
        # mode, so any dropout inside it is active while training — confirm.
        for p in self.wav2vec.parameters():
            p.requires_grad = False

        hidden = self.wav2vec.config.hidden_size
        self.pool = AttentivePooling(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return logits for a waveform batch given as (B, T) or (B, 1, T)."""
        wav = x.squeeze(1) if x.dim() == 3 else x
        frames = self.wav2vec(wav).last_hidden_state
        return self.classifier(self.pool(frames))
# ============================================================
# 2. AASIST (SOTA Graph-based Baseline)
# ============================================================
import random
from typing import Union
import numpy as np
from torch import Tensor
# Original simplistic Graph Attention/Block kept for the Custom model dependent on it
class GraphAttention(nn.Module):
    """Single-head attention over a fully connected node set.

    Nodes are projected by ``fc``. The edge logit for (i, j) is
    ``attn([h_i, h_j])``, evaluated without building the pairwise
    (..., N, N, 2 * out_dim) tensor: the linear ``attn`` weight is split into
    the half that multiplies h_i and the half that multiplies h_j, and the two
    per-node scores are broadcast-added.

    NOTE(review): the logit is e_ij = s_i + t_j + b with no nonlinearity, so
    the softmax over j cancels s_i and b. Every row of the attention matrix is
    therefore softmax(t), and every output node receives the same weighted
    average of all nodes. A standard GAT applies LeakyReLU before the softmax.
    This is kept unchanged so that trained checkpoints behave identically.

    Args:
        in_dim: input feature size.
        out_dim: projected feature size (also the output feature size).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.attn = nn.Linear(out_dim * 2, 1)

    def forward(self, x):
        """Map x of shape (..., N, in_dim) to (..., N, out_dim).

        Any number of leading batch dims is accepted, including none (a single
        unbatched (N, in_dim) graph).
        """
        h = self.fc(x)
        d = h.shape[-1]
        # attn.weight is (1, 2 * out_dim); flatten it without squeezing away
        # anything else.
        w = self.attn.weight.reshape(-1)
        w_src, w_dst = w[:d], w[d:]

        score_i = torch.matmul(h, w_src).unsqueeze(-1)  # (..., N, 1)
        score_j = torch.matmul(h, w_dst).unsqueeze(-2)  # (..., 1, N)
        e = score_i + score_j  # (..., N, N)
        if self.attn.bias is not None:
            e = e + self.attn.bias

        alpha = F.softmax(e, dim=-1)
        return torch.matmul(alpha, h)
class GraphBlock(nn.Module):
    """Post-norm residual block: ``LayerNorm(x + Dropout(GraphAttention(x)))``.

    Args:
        dim: feature size of the nodes; input and output sizes are equal.
    """

    def __init__(self, dim):
        super().__init__()
        self.gat = GraphAttention(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        update = self.dropout(self.gat(x))
        return self.norm(x + update)
class GraphAttentionLayer(nn.Module):
    """Homogeneous graph attention layer (AASIST).

    Every node attends to every node. Edge logits come from the element-wise
    product of the two endpoint features, so the intermediate tensor is
    (#bs, #node, #node, in_dim).

    Args:
        in_dim: feature size of the input nodes.
        out_dim: feature size of the output nodes.
        temperature (keyword, optional): divisor applied to the attention
            logits before the softmax; defaults to 1.0.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        # Attention-map parameters.
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)
        # Output projections: one applied after neighbour aggregation, one to
        # the node itself; their sum is the update.
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        '''
        x :(#bs, #node, #dim) -> (#bs, #node, out_dim)
        '''
        x = self.input_drop(x)
        att_map = self._derive_att_map(x)
        projected = self._project(x, att_map)
        return self.act(self._apply_BN(projected))

    def _pairwise_mul_nodes(self, x):
        '''
        Element-wise product of every node pair.
        x :(#bs, #node, #dim) -> (#bs, #node, #node, #dim)
        '''
        n = x.size(1)
        rows = x.unsqueeze(2).expand(-1, -1, n, -1)
        return rows * rows.transpose(1, 2)

    def _derive_att_map(self, x):
        '''
        x :(#bs, #node, #dim) -> attention weights (#bs, #node, #node, 1),
        normalised over dim -2.
        '''
        pair_feats = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        logits = torch.matmul(pair_feats, self.att_weight) / self.temp
        return F.softmax(logits, dim=-2)

    def _project(self, x, att_map):
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _apply_BN(self, x):
        # BatchNorm1d over the feature axis, treating every node as a sample.
        shape = x.size()
        flat = self.bn(x.view(-1, shape[-1]))
        return flat.view(shape)

    def _init_new_params(self, *size):
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer (AASIST).

    Two node sets are projected with separate linear maps, concatenated into
    one graph and updated jointly. Attention logits use one weight vector per
    edge type: type1-type1 and type2-type2 each have their own, and both
    cross-type directions share ``att_weight12``. A "master" node is also
    updated by attending over all nodes.

    Args:
        in_dim: feature size of x1, x2 and of an explicitly passed ``master``.
        out_dim: feature size of every output (both node sets and master).
        temperature (keyword, optional): divisor applied to all attention
            logits before the softmax; defaults to 1.0.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)
        # Attention-map parameters (node pairs and node/master).
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)
        # Output projections for ordinary nodes and for the master node.
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x1, x2, master=None):
        '''
        x1 :(#bs, #node1, #dim)
        x2 :(#bs, #node2, #dim)
        master : (#bs or 1, 1, #dim) or None (then the mean of all projected
                 nodes is used)
        returns (x1_out, x2_out, master_out), each with out_dim features
        '''
        n1 = x1.size(1)
        n2 = x2.size(1)
        x = torch.cat([self.proj_type1(x1), self.proj_type2(x2)], dim=1)
        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        x = self.input_drop(x)
        att_map = self._derive_att_map(x, n1, n2)
        # The master is updated from the same (dropped-out) node features.
        master = self._update_master(x, master)
        x = self.act(self._apply_BN(self._project(x, att_map)))
        return x.narrow(1, 0, n1), x.narrow(1, n1, n2), master

    def _update_master(self, x, master):
        weights = self._derive_att_map_master(x, master)
        return self._project_master(x, master, weights)

    def _pairwise_mul_nodes(self, x):
        '''
        Element-wise product of every node pair.
        x :(#bs, #node, #dim) -> (#bs, #node, #node, #dim)
        '''
        n = x.size(1)
        rows = x.unsqueeze(2).expand(-1, -1, n, -1)
        return rows * rows.transpose(1, 2)

    def _derive_att_map_master(self, x, master):
        '''
        x :(#bs, #node, #dim) -> weights (#bs, #node, 1) normalised over nodes
        '''
        feats = torch.tanh(self.att_projM(x * master))
        logits = torch.matmul(feats, self.att_weightM) / self.temp
        return F.softmax(logits, dim=-2)

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x :(#bs, #node, #dim) -> weights (#bs, #node, #node, 1) normalised
        over dim -2. The first num_type1 nodes are type 1, the rest type 2.
        '''
        pair_feats = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        logits = torch.zeros_like(pair_feats[:, :, :, 0]).unsqueeze(-1)
        type1 = slice(None, num_type1)
        type2 = slice(num_type1, None)
        edge_weights = (
            (type1, type1, self.att_weight11),
            (type2, type2, self.att_weight22),
            (type1, type2, self.att_weight12),
            (type2, type1, self.att_weight12),  # cross-type weight is shared
        )
        for rows, cols, weight in edge_weights:
            logits[:, rows, cols, :] = torch.matmul(pair_feats[:, rows, cols, :], weight)
        logits = logits / self.temp
        return F.softmax(logits, dim=-2)

    def _project(self, x, att_map):
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _project_master(self, x, master, att_map):
        # (#bs, 1, #node) @ (#bs, #node, #dim) -> (#bs, 1, #dim)
        pooled = torch.matmul(att_map.squeeze(-1).unsqueeze(1), x)
        return self.proj_with_attM(pooled) + self.proj_without_attM(master)

    def _apply_BN(self, x):
        # BatchNorm1d over the feature axis, treating every node as a sample.
        shape = x.size()
        flat = self.bn(x.view(-1, shape[-1]))
        return flat.view(shape)

    def _init_new_params(self, *size):
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param
class GraphPool(nn.Module):
    """Score-and-select graph pooling (AASIST).

    Each node gets a sigmoid score from a linear projection of the
    (optionally dropped-out) features. Nodes are gated by their score, and
    the ``max(int(N * k), 1)`` highest-scoring ones are kept.

    Args:
        k: fraction of nodes to keep.
        in_dim: node feature size.
        p: dropout probability applied before scoring (0 disables it).
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """h: (#bs, #node, #dim) -> (#bs, max(int(#node * k), 1), #dim)."""
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        # scores: (#bs, #node, 1); selection uses scores, output is h * scores.
        _, total_nodes, feat_dim = h.size()
        keep = max(int(total_nodes * k), 1)
        _, top_idx = torch.topk(scores, keep, dim=1)
        gather_idx = top_idx.expand(-1, -1, feat_dim)
        gated = h * scores
        return torch.gather(gated, 1, gather_idx)
class CONV(nn.Module):
    """Fixed sinc band-pass filterbank used as the AASIST front end.

    ``out_channels`` band-pass FIR filters are built once in ``__init__`` and
    are never trained. Their band edges are equally spaced on the mel scale
    between 0 Hz and ``sample_rate / 2``. Each filter is a Hamming-windowed
    difference of two sinc low-passes.

    Args:
        out_channels: number of band-pass filters.
        kernel_size: filter length; bumped to the next odd value so every
            filter is symmetric.
        sample_rate: sampling rate in Hz used to place the band edges.
        in_channels: must be 1.
        stride, padding, dilation: forwarded to ``F.conv1d``.
        bias, groups: unsupported; a truthy ``bias`` or ``groups > 1`` raises.
        mask: stored on the instance; ``forward`` uses its own ``mask`` argument.

    Raises:
        ValueError: for ``in_channels != 1``, ``bias=True`` or ``groups > 1``.
    """

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only supports one input channel (here, in_channels = %i)" % in_channels
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        # Force an odd length so the filters are perfectly symmetric.
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        filbandwidthsmel = np.linspace(np.min(fmel), np.max(fmel), self.out_channels + 1)
        # Band edges in Hz (despite the attribute name).
        self.mel = self.to_hz(filbandwidthsmel)
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)

        band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2 * fmax / self.sample_rate) * \
                np.sinc(2 * fmax * self.hsupp / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * \
                np.sinc(2 * fmin * self.hsupp / self.sample_rate)
            hideal = hHigh - hLow
            band_pass[i, :] = Tensor(np.hamming(self.kernel_size)) * Tensor(hideal)
        # Non-persistent buffer: it follows .to()/.half()/.double() like any
        # buffer, but it is not written to (or expected in) state_dict. That
        # keeps checkpoints saved while this was a plain attribute loadable
        # with strict=True.
        self.register_buffer("band_pass", band_pass, persistent=False)

    def forward(self, x, mask=False):
        """Filter x of shape (B, 1, T) -> (B, out_channels, T').

        If ``mask`` is true, a random contiguous run of up to 19 filters is
        zeroed (frequency-masking augmentation). The run is clamped to the
        number of filters so small filterbanks do not crash.
        """
        # Clone so masking never touches the stored filterbank; match the
        # input's device and dtype so half/double inputs work.
        band_pass_filter = self.band_pass.clone().to(device=x.device, dtype=x.dtype)
        if mask:
            n_filters = band_pass_filter.shape[0]
            width = min(int(np.random.uniform(0, 20)), n_filters)
            start = random.randint(0, n_filters - width)
            band_pass_filter[start:start + width, :] = 0
        self.filters = band_pass_filter.view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
class Residual_block(nn.Module):
    """2-D residual block of the AASIST encoder, followed by (1, 3) max-pooling.

    Path: conv1 (2x3, pad (1,1)) -> BN -> SELU -> conv2 (2x3, pad (0,1)),
    then a residual add. The shortcut goes through a 1x3 conv when the channel
    count changes. The height (frequency) axis is preserved. The width (time)
    axis is divided by 3 by the final max-pool.

    Args:
        nb_filts: (in_channels, out_channels).
        first: if True, the block has no ``bn1``.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_ch, out_ch = nb_filts[0], nb_filts[1]
        self.first = first
        if not first:
            self.bn1 = nn.BatchNorm2d(num_features=in_ch)
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)
        self.downsample = in_ch != out_ch
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_ch,
                                             out_channels=out_ch,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        if not self.first:
            # NOTE(review): this pre-activation is computed and then discarded;
            # conv1 below reads the raw `x`. The quirk is kept on purpose
            # (bn1 running stats still update in train mode) so existing
            # checkpoints keep producing the same outputs.
            self.selu(self.bn1(x))
        out = self.conv2(self.selu(self.bn2(self.conv1(x))))
        shortcut = self.conv_downsample(x) if self.downsample else x
        out += shortcut
        return self.mp(out)
class AASISTModel(nn.Module):
    """AASIST: sinc front end, residual CNN encoder, spectral and temporal
    graphs, and two heterogeneous graph-attention branches with master nodes.

    Args:
        d_args: configuration dict. Keys read here: ``"filts"`` (sinc channel
            count followed by residual-block channel pairs), ``"first_conv"``
            (sinc kernel size), ``"gat_dims"``, ``"pool_ratios"`` and
            ``"temperatures"``.
    """

    def __init__(self, d_args):
        super().__init__()
        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # Number of spectral graph nodes. The (3, 3) max-pool in forward()
        # reduces the filts[0] sinc channels to filts[0] // 3 rows (23 for
        # the default 70), and the residual encoder keeps that axis unchanged.
        n_spectral_nodes = filts[0] // 3
        self.pos_S = nn.Parameter(torch.randn(1, n_spectral_nodes, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        """Run AASIST on a batch of raw waveforms.

        Args:
            x: waveform batch. It is unsqueezed to (B, 1, T) for the sinc
                front end, so it is expected as (B, T).
            Freq_aug: forwarded to CONV as ``mask``; zeroes a random run of
                sinc filters (frequency-masking augmentation).

        Returns:
            (last_hidden, output): the (B, 5 * gat_dims[1]) embedding after
            dropout, and the logits from ``out_layer``.
        """
        # Sinc front end; the filter axis then becomes the 2-D "height".
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)
        batch = x.size(0)

        e = self.encoder(x)

        # Spectral graph: max over time (dim 3), one node per spectral row.
        e_S, _ = torch.max(torch.abs(e), dim=3)
        e_S = e_S.transpose(1, 2) + self.pos_S
        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)

        # Temporal graph: max over frequency (dim 2), one node per frame.
        e_T, _ = torch.max(torch.abs(e), dim=2)
        e_T = e_T.transpose(1, 2)
        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # Two parallel heterogeneous branches, each with its own learnable
        # master node expanded to the batch size.
        master1 = self.master1.expand(batch, -1, -1)
        master2 = self.master2.expand(batch, -1, -1)

        # Branch 1.
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=master1)
        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # Branch 2.
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # Element-wise max merges the two branches.
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)
        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)
        return last_hidden, output
class AASISTDetector(nn.Module):
    """AASIST with fixed hyper-parameters, returning logits only.

    Args:
        num_classes: number of logits. Anything other than 2 replaces the
            model's output layer with a freshly initialised Linear.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        config = {
            "nb_samp": 64600,
            "first_conv": 128,
            "in_channels": 1,
            "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
            "gat_dims": [64, 32],
            "pool_ratios": [0.5, 0.7, 0.5, 0.5],
            "temperatures": [2.0, 2.0, 100.0]
        }
        self.model = AASISTModel(config)
        if num_classes != 2:
            embed_dim = 5 * config["gat_dims"][1]
            self.model.out_layer = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits for a waveform batch given as (B, T) or (B, 1, T)."""
        wav = x.squeeze(1) if x.dim() == 3 else x
        _, logits = self.model(wav)
        return logits
# ============================================================
# 3. CQCC Baseline Detector (Acoustic Feature Baseline)
# ============================================================
class CQCCBaselineDetector(nn.Module):
    """Small 2-D CNN baseline over a CQCC feature map.

    The network has three Conv2d(3x3) + BatchNorm + ReLU stages (16, 32 and
    64 channels, with 2x2 max-pooling after the first two), then global
    average pooling, dropout and a linear head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Map CQCC features to logits of shape (B, num_classes).

        Accepts (B, 1, n_coeff, T) or (B, n_coeff, T). The 3-D form is the
        layout the Conv1d-based CQCC encoders take. It previously failed
        inside Conv2d, so a channel axis is now added for it.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)
# ============================================================
# 4. Custom Fusional Wav2Vec2 + CQCC with Cross-Attention + Graph
# ============================================================
class PositionalEncoding(nn.Module):
    """Learnable absolute position embedding added to a sequence.

    The table is a (1, max_len, dim) parameter initialised with
    ``torch.randn``. Its first ``x.size(1)`` rows are added to x, which is
    assumed to be (batch, time, dim). A sequence longer than ``max_len``
    cannot be broadcast against the truncated table.

    Args:
        dim: feature size.
        max_len: number of positions in the table.
    """

    def __init__(self, dim, max_len=6000):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim))

    def forward(self, x):
        seq_len = x.size(1)
        offsets = self.pos_embed[:, :seq_len]
        return x + offsets
class BidirectionalCrossAttention(nn.Module):
    """Cross-attention in both directions between two sequences.

    ``out1`` is x1 attending to x2 (via ``attn1``), and ``out2`` is x2
    attending to x1 (via ``attn2``). Queries always go through ``norm_q`` and
    keys/values through ``norm_kv``. The same two LayerNorms are therefore
    shared by both inputs, depending on the role each plays. No residual
    connection is added.

    Args:
        dim: embedding size of both inputs.
        num_heads: attention heads per direction.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
        self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)

    def _attend(self, attention, query_src, context_src):
        # Normalise query and context, then use the context as both K and V.
        query = self.norm_q(query_src)
        context = self.norm_kv(context_src)
        attended, _ = attention(query, context, context)
        return attended

    def forward(self, x1, x2):
        """Return (out1, out2) with the sequence lengths of x1 and x2."""
        out1 = self._attend(self.attn1, x1, x2)
        out2 = self._attend(self.attn2, x2, x1)
        return out1, out2
def align_sequences(x, target_len):
    """Resample a (batch, time, dim) sequence to ``target_len`` time steps.

    Uses 1-D linear interpolation along dim 1 with ``align_corners=False``.
    """
    channels_first = x.transpose(1, 2)
    resized = F.interpolate(channels_first, size=target_len, mode='linear', align_corners=False)
    return resized.transpose(1, 2)
class ImprovedWav2Vec2CQCCDetector(nn.Module):
    """Fusion model: frozen wav2vec2 + CQCC Conv1d encoder, bidirectional
    cross-attention, three GraphBlocks, mean pooling and an MLP head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Wav2Vec2, frozen so it acts purely as a feature extractor.
        # NOTE(review): it still follows this module's train()/eval() mode,
        # so any dropout inside it is active while training — confirm.
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size

        # CQCC encoder: 20 coefficient channels -> wav2vec2 hidden size.
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        # A single learnable position table, shared by both streams.
        self.pos_enc = PositionalEncoding(dim)
        self.cross_attn = BidirectionalCrossAttention(dim)
        # Graph backend: a stack of GraphBlock layers over the frame sequence.
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        """Return logits (B, num_classes).

        Args:
            wav: waveform, (B, T) or (B, 1, T).
            cqcc: CQCC features, (B, 20, T_c) or (B, 1, 20, T_c). They are
                resampled to the wav2vec2 frame count before fusion.
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)

        w2v = self.pos_enc(self.wav2vec(wav).last_hidden_state)  # (B, T_w, D)
        cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)  # (B, T_c, D)
        cqcc_feat = self.pos_enc(align_sequences(cqcc_feat, w2v.size(1)))

        from_cqcc, from_w2v = self.cross_attn(cqcc_feat, w2v)
        nodes = from_cqcc + from_w2v
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
# ============================================================
# 5. Ablation Models
# ============================================================
class AblationWav2Vec2GraphDetector(nn.Module):
    """Ablation 1: Wav2Vec2 only + Graph Backend (No CQCC, No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen feature extractor.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc=None):
        """Return logits for a (B, T) or (B, 1, T) waveform.

        ``cqcc`` is accepted so the call signature matches the fusion models;
        it is ignored.
        """
        audio = wav.squeeze(1) if wav.dim() == 3 else wav
        nodes = self.pos_enc(self.wav2vec(audio).last_hidden_state)
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationCQCCGraphDetector(nn.Module):
    """Ablation 2: CQCC only + Graph Backend (No Wav2Vec2, No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        # 768 matches the wav2vec2-base hidden size used by the other models,
        # so the backend width is the same across ablations.
        dim = 768
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, cqcc):
        """Return logits for CQCC features (B, 20, T) or (B, 1, 20, T)."""
        feats = cqcc.squeeze(1) if cqcc.dim() == 4 else cqcc
        nodes = self.pos_enc(self.cqcc_conv(feats).transpose(1, 2))
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationConcatGraphDetector(nn.Module):
    """Ablation 3: Wav2Vec2 + CQCC + Simple Concat Fusion + Graph Backend (No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen feature extractor.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        # Maps the [wav2vec2 ; CQCC] concatenation (2 * dim) back to dim.
        self.fusion_proj = nn.Linear(dim * 2, dim)
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        """Return logits.

        Args:
            wav: waveform, (B, T) or (B, 1, T).
            cqcc: CQCC features, (B, 20, T_c) or (B, 1, 20, T_c). They are
                resampled to the wav2vec2 frame count and concatenated on the
                feature axis. Positions are added after fusion.
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        w2v = self.wav2vec(wav).last_hidden_state
        cqcc_feat = align_sequences(self.cqcc_conv(cqcc).transpose(1, 2), w2v.size(1))
        joined = torch.cat([w2v, cqcc_feat], dim=-1)
        nodes = self.pos_enc(self.fusion_proj(joined))
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationCrossAttnLinearDetector(nn.Module):
    """Ablation 4: Wav2Vec2 + CQCC + Cross-Attention + Linear Backend (No Graph Transformer)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen feature extractor.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        self.pos_enc = PositionalEncoding(dim)
        self.cross_attn = BidirectionalCrossAttention(dim)
        # Deeper MLP head to partly compensate for the missing graph layers.
        self.classifier = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        """Return logits; the fused sequence is mean-pooled directly.

        Args:
            wav: waveform, (B, T) or (B, 1, T).
            cqcc: CQCC features, (B, 20, T_c) or (B, 1, 20, T_c).
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        w2v = self.pos_enc(self.wav2vec(wav).last_hidden_state)
        cqcc_feat = align_sequences(self.cqcc_conv(cqcc).transpose(1, 2), w2v.size(1))
        cqcc_feat = self.pos_enc(cqcc_feat)
        from_cqcc, from_w2v = self.cross_attn(cqcc_feat, w2v)
        fused = from_cqcc + from_w2v
        return self.classifier(fused.mean(dim=1))