Spaces:
Sleeping
Sleeping
| """ | |
| Classification Heads — Pooling strategies and MLP classifier for deepfake detection. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| class AttentiveStatsPooling(nn.Module): | |
| """ | |
| Attentive Statistics Pooling. | |
| Learns which frames are most important, then computes weighted mean + std. | |
| Used in ECAPA-TDNN and top speaker verification systems. | |
| """ | |
| def __init__(self, hidden_size: int, attention_dim: int = 128): | |
| super().__init__() | |
| self.attention = nn.Sequential( | |
| nn.Linear(hidden_size, attention_dim), | |
| nn.Tanh(), | |
| nn.Linear(attention_dim, 1), | |
| ) | |
| self.output_size = hidden_size * 2 # mean + std concatenated | |
| def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: | |
| """ | |
| Args: | |
| x: (batch, time, hidden_size) | |
| mask: optional (batch, time) boolean mask | |
| Returns: | |
| (batch, hidden_size * 2) — weighted mean and std | |
| """ | |
| # Compute attention weights | |
| attn_weights = self.attention(x).squeeze(-1) # (batch, time) | |
| if mask is not None: | |
| attn_weights = attn_weights.masked_fill(~mask, float("-inf")) | |
| attn_weights = F.softmax(attn_weights, dim=-1).unsqueeze(-1) # (batch, time, 1) | |
| # Weighted mean | |
| mean = torch.sum(x * attn_weights, dim=1) # (batch, hidden) | |
| # Weighted std | |
| var = torch.sum(attn_weights * (x - mean.unsqueeze(1)) ** 2, dim=1) | |
| std = torch.sqrt(var.clamp(min=1e-6)) | |
| return torch.cat([mean, std], dim=-1) # (batch, hidden*2) | |
| class MultiHeadAttentionPooling(nn.Module): | |
| """ | |
| Multi-Head Attention Pooling. | |
| Applies multi-head self-attention then pools via learned query vector. | |
| """ | |
| def __init__(self, hidden_size: int, num_heads: int = 4): | |
| super().__init__() | |
| self.num_heads = num_heads | |
| self.query = nn.Parameter(torch.randn(1, 1, hidden_size)) | |
| self.mha = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True) | |
| self.output_size = hidden_size | |
| nn.init.xavier_uniform_(self.query) | |
| def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: | |
| """ | |
| Args: | |
| x: (batch, time, hidden_size) | |
| Returns: | |
| (batch, hidden_size) | |
| """ | |
| batch_size = x.size(0) | |
| query = self.query.expand(batch_size, -1, -1) # (batch, 1, hidden) | |
| out, _ = self.mha(query, x, x) # (batch, 1, hidden) | |
| return out.squeeze(1) # (batch, hidden) | |
| class MeanPooling(nn.Module): | |
| """Simple mean pooling over the time axis.""" | |
| def __init__(self, hidden_size: int): | |
| super().__init__() | |
| self.output_size = hidden_size | |
| def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: | |
| if mask is not None: | |
| x = x * mask.unsqueeze(-1).float() | |
| return x.sum(dim=1) / mask.sum(dim=1, keepdim=True).float() | |
| return x.mean(dim=1) | |
| class DeepfakeClassifier(nn.Module): | |
| """ | |
| Full classification model = Backbone + Pooling + MLP Head. | |
| """ | |
| def __init__(self, backbone: nn.Module, hidden_size: int, | |
| num_labels: int = 2, classifier_hidden: int = 256, | |
| dropout: float = 0.3, pooling_type: str = "attentive_stats"): | |
| super().__init__() | |
| self.backbone = backbone | |
| # Select pooling strategy | |
| if pooling_type == "attentive_stats": | |
| self.pooling = AttentiveStatsPooling(hidden_size) | |
| elif pooling_type == "multi_head": | |
| self.pooling = MultiHeadAttentionPooling(hidden_size) | |
| elif pooling_type == "mean": | |
| self.pooling = MeanPooling(hidden_size) | |
| else: | |
| raise ValueError(f"Unknown pooling: {pooling_type}") | |
| pool_output_size = self.pooling.output_size | |
| # MLP classification head with batch norm | |
| self.classifier = nn.Sequential( | |
| nn.Linear(pool_output_size, classifier_hidden), | |
| nn.BatchNorm1d(classifier_hidden), | |
| nn.ReLU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(classifier_hidden, classifier_hidden // 2), | |
| nn.BatchNorm1d(classifier_hidden // 2), | |
| nn.ReLU(), | |
| nn.Dropout(dropout / 2), | |
| nn.Linear(classifier_hidden // 2, num_labels), | |
| ) | |
| # Initialize weights | |
| self._init_weights() | |
| def _init_weights(self): | |
| for m in self.classifier: | |
| if isinstance(m, nn.Linear): | |
| nn.init.kaiming_normal_(m.weight, nonlinearity="relu") | |
| nn.init.zeros_(m.bias) | |
| def forward(self, input_values: torch.Tensor, | |
| attention_mask: torch.Tensor = None) -> torch.Tensor: | |
| """ | |
| Args: | |
| input_values: (batch, time) raw waveform | |
| attention_mask: (batch, time) attention mask | |
| Returns: | |
| logits: (batch, num_labels) | |
| """ | |
| # Extract features from backbone | |
| outputs = self.backbone(input_values, attention_mask=attention_mask) | |
| hidden_states = outputs.last_hidden_state # (batch, seq_len, hidden) | |
| # Pool across time | |
| pooled = self.pooling(hidden_states) # (batch, pool_dim) | |
| # Classify | |
| logits = self.classifier(pooled) # (batch, num_labels) | |
| return logits | |
| def extract_embeddings(self, input_values: torch.Tensor) -> torch.Tensor: | |
| """Extract embeddings (before classification head) for analysis.""" | |
| outputs = self.backbone(input_values) | |
| hidden_states = outputs.last_hidden_state | |
| return self.pooling(hidden_states) | |