# Audio anti-spoofing / deepfake detector models:
# Wav2Vec2 baseline, AASIST, CQCC baseline, and Wav2Vec2+CQCC fusion & ablations.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import Wav2Vec2Model | |
| # ============================================================ | |
| # 1. Wav2Vec2 Detector (Self-supervised Transformer Baseline) | |
| # ============================================================ | |
class AttentivePooling(nn.Module):
    """Collapse a (B, T, D) frame sequence to (B, D) with learned attention."""

    def __init__(self, dim):
        super().__init__()
        # Two-layer scorer producing one scalar score per time step.
        self.attn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Tanh(),
            nn.Linear(dim, 1)
        )

    def forward(self, x):
        # Normalise scores over the time axis, then take the weighted sum.
        weights = self.attn(x).softmax(dim=1)
        return (weights * x).sum(dim=1)
class Wav2Vec2SpoofDetector(nn.Module):
    """Spoof detector: frozen Wav2Vec2 encoder + attentive pooling + linear head."""

    def __init__(self, num_classes=2, model_name="facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
        # Keep the SSL backbone frozen; only pooling + classifier are trained.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        hidden = self.wav2vec.config.hidden_size
        self.pool = AttentivePooling(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_classes)
        )

    def forward(self, x):
        # Accept (B, 1, T) waveforms as well as (B, T).
        if x.dim() == 3:
            x = x.squeeze(1)
        frames = self.wav2vec(x).last_hidden_state
        return self.classifier(self.pool(frames))
| # ============================================================ | |
| # 2. AASIST (SOTA Graph-based Baseline) | |
| # ============================================================ | |
| import random | |
| from typing import Union | |
| import numpy as np | |
| from torch import Tensor | |
| # Original simplistic Graph Attention/Block kept for the Custom model dependent on it | |
class GraphAttention(nn.Module):
    """Single-head dense graph attention over a fully connected node set.

    The pairwise score e_ij = W1·h_i + W2·h_j (+ bias) is obtained by
    splitting the attention weight vector in two, instead of materialising
    the O(N^2 * 2D) concatenation tensor — memory stays at O(N^2).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.attn = nn.Linear(out_dim * 2, 1)

    def forward(self, x):
        h = self.fc(x)
        dim = h.shape[-1]
        # Split the (2D,) attention vector into "source" and "target" halves.
        weight = self.attn.weight.squeeze()
        src_score = torch.matmul(h, weight[:dim]).unsqueeze(-1)   # (B, N, 1)
        dst_score = torch.matmul(h, weight[dim:]).unsqueeze(-1)   # (B, N, 1)
        # (B, N, 1) + (B, 1, N) broadcasts to the full (B, N, N) score matrix.
        scores = src_score + dst_score.transpose(1, 2)
        if self.attn.bias is not None:
            scores = scores + self.attn.bias
        alpha = F.softmax(scores, dim=-1)
        return torch.matmul(alpha, h)
class GraphBlock(nn.Module):
    """Residual graph-attention block: GAT -> dropout -> add & layer-norm."""

    def __init__(self, dim):
        super().__init__()
        self.gat = GraphAttention(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        attended = self.dropout(self.gat(x))
        return self.norm(attended + x)
class GraphAttentionLayer(nn.Module):
    """AASIST-style graph attention over a homogeneous node set.

    Attention scores come from element-wise products of node pairs; the
    attended projection is summed with a plain projection, then batch-norm
    and SELU are applied.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)
        # node projections, with and without attention
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)
        # softmax temperature for the attention map
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        """x: (#bs, #node, #dim) -> (#bs, #node, out_dim)"""
        x = self.input_drop(x)
        att_map = self._derive_att_map(x)
        x = self._project(x, att_map)
        x = self._apply_BN(x)
        return self.act(x)

    def _pairwise_mul_nodes(self, x):
        """(#bs, #node, #dim) -> (#bs, #node, #node, #dim) pairwise products."""
        n = x.size(1)
        expanded = x.unsqueeze(2).expand(-1, -1, n, -1)
        return expanded * expanded.transpose(1, 2)

    def _derive_att_map(self, x):
        """(#bs, #node, #dim) -> attention map (#bs, #node, #node, 1),
        softmax-normalised over the source-node axis (dim=-2)."""
        att_map = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        att_map = torch.matmul(att_map, self.att_weight) / self.temp
        return F.softmax(att_map, dim=-2)

    def _project(self, x, att_map):
        attended = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        return attended + self.proj_without_att(x)

    def _apply_BN(self, x):
        # Flatten to (#bs * #node, #dim) so BatchNorm1d normalises per feature.
        shape = x.size()
        return self.bn(x.view(-1, shape[-1])).view(shape)

    def _init_new_params(self, *size):
        # Xavier-initialised free parameter (attention weight vector).
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer (AASIST).

    Operates on two node sets (e.g. temporal and spectral graph nodes) plus a
    single "master" node. Pairwise attention uses a separate learned weight
    vector for each node-type pairing (11, 22, 12 — the 12 vector is reused
    for the 21 quadrant); the master node is updated from all nodes via its
    own attention map.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        # type-specific input projections, applied before concatenation
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)
        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)
        # one attention weight vector per node-type pairing (12 shared with 21)
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)
        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        # master-node projections
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)
        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)
        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)
        # activate
        self.act = nn.SELU(inplace=True)
        # softmax temperature for the attention maps
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1 :(#bs, #node, #dim)  first node set
        x2 :(#bs, #node, #dim)  second node set
        master : master-node tensor broadcastable to (#bs, 1, #dim);
                 defaults to the mean over all projected nodes
        returns (x1, x2, master), node features projected to out_dim
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)
        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)
        # joint node sequence: type-1 nodes first, then type-2 nodes
        x = torch.cat([x1, x2], dim=1)
        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)
        # apply input dropout
        x = self.input_drop(x)
        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)
        # directional edge for master node
        master = self._update_master(x, master)
        # projection
        x = self._project(x, att_map)
        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        # split the joint sequence back into the two typed node sets
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)
        return x1, x2, master

    def _update_master(self, x, master):
        # attend over all nodes, then fold the result into the master node
        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)
        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim)
        '''
        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)
        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        Master-node attention: scores every node against the master.
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, 1), softmax-normalised over nodes
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))
        att_map = torch.matmul(att_map, self.att_weightM)
        # apply temperature
        att_map = att_map / self.temp
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        # assemble the block-structured score matrix: each quadrant uses the
        # attention vector of its node-type pairing (12 reused for 21)
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)
        att_map = att_board
        # apply temperature
        att_map = att_map / self.temp
        # normalise over the source-node axis
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)
        return x1 + x2

    def _project_master(self, x, master, att_map):
        # attention-weighted sum of all nodes plus a plain master projection
        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)
        return x1 + x2

    def _apply_BN(self, x):
        # flatten to (#bs * #node, #dim) so BatchNorm1d normalises per feature
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)
        return x

    def _init_new_params(self, *size):
        # Xavier-initialised free parameter (attention weight vectors)
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out
class GraphPool(nn.Module):
    """Top-k graph pooling: score each node, keep the highest-scoring
    fraction k (at least one node), gating features by their scores."""

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        # Scores come from a dropped-out copy; gating uses the raw features.
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """Keep the top int(N*k) nodes (min 1) of score-gated features."""
        _, n_nodes, n_feat = h.size()
        keep = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, keep, dim=1)     # (B, keep, 1)
        idx = idx.expand(-1, -1, n_feat)
        return torch.gather(h * scores, 1, idx)
class CONV(nn.Module):
    """Sinc-filter convolution front-end (fixed mel-spaced band-pass bank).

    Filters are Hamming-windowed ideal band-pass responses whose cut-offs
    are spaced linearly on the mel scale; they are built once in __init__
    and applied with F.conv1d. When forward(..., mask=True), a random
    contiguous run of filters is zeroed (frequency-masking augmentation).
    """

    # BUGFIX: these were plain instance methods, so `self.to_mel(f)` passed
    # `self` as `hz` and __init__ crashed with a TypeError. They are pure
    # unit conversions and must be static (as in the upstream reference).
    @staticmethod
    def to_mel(hz):
        """Convert frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert mel-scale values back to Hz."""
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        # out_channels + 1 mel-spaced band edges -> out_channels bands.
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # Ideal band-pass = difference of two windowed sinc low-passes.
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow
            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """x: (B, 1, T) waveform -> (B, out_channels, T') filter responses."""
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            # Zero a random contiguous band of up to 20 filters.
            A = int(np.random.uniform(0, 20))
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)
        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
class Residual_block(nn.Module):
    """Pre-activation 2D residual block used by the AASIST encoder.

    nb_filts = [in_channels, out_channels]. When the channel counts differ,
    a (1, 3) convolution projects the identity branch. A (1, 3) max-pool
    shrinks the last (time) axis after the residual sum.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first
        # The very first block receives raw encoder input, so it skips the
        # leading BN + SELU pre-activation.
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        # BUGFIX: feed the pre-activated tensor into conv1. The original
        # (matching the upstream AASIST reference) applied conv1 to the raw
        # input `x`, silently discarding the bn1 + selu result above.
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out
class AASISTModel(nn.Module):
    """AASIST anti-spoofing model.

    Pipeline: sinc-conv front-end (CONV) -> residual CNN encoder ->
    spectral (S) and temporal (T) node sequences -> homogeneous graph
    attention per branch -> two parallel heterogeneous S-T graph branches
    with master nodes -> element-wise max merge -> readout.

    forward returns (last_hidden, logits); the default head has 2 classes.
    """

    def __init__(self, d_args):
        super().__init__()
        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]
        # Fixed sinc filterbank over the raw waveform.
        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)
        # Residual encoder producing a (B, C, F, T) feature map.
        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))
        # Positional embedding for the spectral nodes (23 matches the
        # frequency-axis size under the default config — confirm if the
        # front-end geometry or input length changes).
        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        # Learnable master-node seeds, one per heterogeneous branch.
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        # Readout over [T_max, T_avg, S_max, S_avg, master].
        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        """x: (B, T) raw waveform; Freq_aug enables sinc-filter masking."""
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)
        e = self.encoder(x)
        # Spectral branch: abs-max over time; nodes are frequency bins.
        e_S, _ = torch.max(torch.abs(e), dim=3)
        e_S = e_S.transpose(1, 2) + self.pos_S
        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)
        # Temporal branch: abs-max over frequency; nodes are time frames.
        e_T, _ = torch.max(torch.abs(e), dim=2)
        e_T = e_T.transpose(1, 2)
        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)
        # BUGFIX: pass the batch-expanded masters. Previously the raw
        # (1, 1, D) parameters were passed and the expanded tensors above
        # were dead code; results were identical only via broadcasting.
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=master1)
        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)
        # Second heterogeneous layer refines the first; residual additions.
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug
        # Branch-level dropout before merging the two parallel ways.
        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)
        # Merge the two branches with an element-wise maximum.
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)
        # Readout: abs-max and mean over nodes for each branch + master.
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)
        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)
        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)
        return last_hidden, output
class AASISTDetector(nn.Module):
    """Wrapper exposing AASIST with the published hyper-parameter config."""

    def __init__(self, num_classes=2):
        super().__init__()
        config = {
            "nb_samp": 64600,
            "first_conv": 128,
            "in_channels": 1,
            "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
            "gat_dims": [64, 32],
            "pool_ratios": [0.5, 0.7, 0.5, 0.5],
            "temperatures": [2.0, 2.0, 100.0]
        }
        self.model = AASISTModel(config)
        # Replace the 2-way head only when a different class count is asked for.
        if num_classes != 2:
            self.model.out_layer = nn.Linear(5 * config["gat_dims"][1], num_classes)

    def forward(self, x):
        # Accept (B, 1, T) or (B, T); AASIST expects (B, T).
        if x.dim() == 3:
            x = x.squeeze(1)
        _, logits = self.model(x)
        return logits
| # ============================================================ | |
| # 3. CQCC Baseline Detector (Acoustic Feature Baseline) | |
| # ============================================================ | |
class CQCCBaselineDetector(nn.Module):
    """Small CNN classifier over CQCC feature maps; input (B, 1, 20, T)."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Three conv stages, each halving spatial size (last one pools to 1x1).
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        feats = self.features(x).flatten(1)
        return self.classifier(feats)
| # ============================================================ | |
| # 4. Custom Fusional Wav2Vec2 + CQCC with Cross-Attention + Graph | |
| # ============================================================ | |
class PositionalEncoding(nn.Module):
    """Learned additive positional embeddings, truncated to the input length."""

    def __init__(self, dim, max_len=6000):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim))

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pos_embed[:, :seq_len]
class BidirectionalCrossAttention(nn.Module):
    """Cross-attention in both directions between two equally-sized streams.

    Note: the query and key/value layer-norms are shared across the two
    directions; each direction has its own multi-head attention module.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
        self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)

    def forward(self, x1, x2):
        # Direction 1: x1 queries attend over x2 keys/values.
        q1, kv2 = self.norm_q(x1), self.norm_kv(x2)
        out1, _ = self.attn1(q1, kv2, kv2)
        # Direction 2: x2 queries attend over x1 keys/values.
        q2, kv1 = self.norm_q(x2), self.norm_kv(x1)
        out2, _ = self.attn2(q2, kv1, kv1)
        return out1, out2
def align_sequences(x, target_len):
    """Resample a (B, T, C) sequence to (B, target_len, C) by linear interpolation."""
    channels_first = x.transpose(1, 2)  # F.interpolate wants (B, C, T)
    resampled = F.interpolate(channels_first, size=target_len,
                              mode='linear', align_corners=False)
    return resampled.transpose(1, 2)
class ImprovedWav2Vec2CQCCDetector(nn.Module):
    """Fusion model: frozen Wav2Vec2 + CQCC encoder, bidirectional
    cross-attention fusion, and a graph-attention backend."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Wav2Vec2 backbone, frozen so it acts purely as a feature extractor.
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        # CQCC encoder: (B, 20, T) -> (B, dim, T)
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        # Positional encoding (shared between both streams).
        self.pos_enc = PositionalEncoding(dim)
        # Bidirectional cross-attention fusion.
        self.cross_attn = BidirectionalCrossAttention(dim)
        # Graph-attention backend over the fused node sequence.
        self.graph_layers = nn.ModuleList([
            GraphBlock(dim) for _ in range(3)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        w2v_seq = self.wav2vec(wav).last_hidden_state       # (B, T_w, D)
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        cqcc_seq = self.cqcc_conv(cqcc).transpose(1, 2)     # (B, T_c, D)
        # Match the CQCC sequence length to the Wav2Vec2 frame count.
        cqcc_seq = align_sequences(cqcc_seq, w2v_seq.size(1))
        w2v_seq = self.pos_enc(w2v_seq)
        cqcc_seq = self.pos_enc(cqcc_seq)
        # Fuse by summing both cross-attention directions.
        a2b, b2a = self.cross_attn(cqcc_seq, w2v_seq)
        nodes = a2b + b2a
        for block in self.graph_layers:
            nodes = block(nodes)
        # Global average pooling over nodes, then classify.
        return self.classifier(nodes.mean(dim=1))
| # ============================================================ | |
| # 5. Ablation Models | |
| # ============================================================ | |
class AblationWav2Vec2GraphDetector(nn.Module):
    """Ablation 1: Wav2Vec2 only + Graph Backend (No CQCC, No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen backbone: only graph layers + head receive gradients.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc=None):
        """cqcc is accepted for API parity with the fusion models but ignored."""
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        nodes = self.pos_enc(self.wav2vec(wav).last_hidden_state)
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationCQCCGraphDetector(nn.Module):
    """Ablation 2: CQCC only + Graph Backend (No Wav2Vec2, No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        # Width chosen to match the Wav2Vec2 hidden size for a fair comparison.
        dim = 768
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, cqcc):
        # Accept (B, 1, 20, T) as well as (B, 20, T).
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        nodes = self.pos_enc(self.cqcc_conv(cqcc).transpose(1, 2))
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationConcatGraphDetector(nn.Module):
    """Ablation 3: Wav2Vec2 + CQCC + Simple Concat Fusion + Graph Backend (No Cross-Attention)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen feature extractor.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        # CQCC branch: (B, 20, T) -> (B, dim, T)
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        # Projects the channel-wise concatenation back down to dim.
        self.fusion_proj = nn.Linear(dim * 2, dim)
        self.pos_enc = PositionalEncoding(dim)
        self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
        self.classifier = nn.Sequential(
            nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        w2v_seq = self.wav2vec(wav).last_hidden_state
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        cqcc_seq = self.cqcc_conv(cqcc).transpose(1, 2)
        cqcc_seq = align_sequences(cqcc_seq, w2v_seq.size(1))
        # Channel-wise concatenation stands in for cross-attention fusion.
        fused = self.fusion_proj(torch.cat([w2v_seq, cqcc_seq], dim=-1))
        nodes = self.pos_enc(fused)
        for block in self.graph_layers:
            nodes = block(nodes)
        return self.classifier(nodes.mean(dim=1))
class AblationCrossAttnLinearDetector(nn.Module):
    """Ablation 4: Wav2Vec2 + CQCC + Cross-Attention + Linear Backend (No Graph Transformer)"""

    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Frozen feature extractor.
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        dim = self.wav2vec.config.hidden_size
        # CQCC branch: (B, 20, T) -> (B, dim, T)
        self.cqcc_conv = nn.Sequential(
            nn.Conv1d(20, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim),
            nn.GELU()
        )
        self.pos_enc = PositionalEncoding(dim)
        self.cross_attn = BidirectionalCrossAttention(dim)
        # Deeper MLP head compensates for the missing graph backend.
        self.classifier = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, wav, cqcc):
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        w2v_seq = self.wav2vec(wav).last_hidden_state
        if cqcc.dim() == 4:
            cqcc = cqcc.squeeze(1)
        cqcc_seq = align_sequences(self.cqcc_conv(cqcc).transpose(1, 2),
                                   w2v_seq.size(1))
        w2v_seq = self.pos_enc(w2v_seq)
        cqcc_seq = self.pos_enc(cqcc_seq)
        a2b, b2a = self.cross_attn(cqcc_seq, w2v_seq)
        # Sum both attention directions, then global average pool over nodes.
        pooled = (a2b + b2a).mean(dim=1)
        return self.classifier(pooled)