Spaces:
Running
Running
# Adapted from: https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py
import torch
import torch.nn as nn

from .convnext import ConvNextBackbone
| class AttentiveStatsPool(nn.Module): | |
| def __init__(self, input_channels: int, output_channels: int, attention_channels: int = 128): | |
| super().__init__() | |
| self.attn = nn.Sequential( | |
| nn.Conv1d(input_channels, attention_channels, kernel_size=1), | |
| nn.Tanh(), | |
| nn.Conv1d(attention_channels, input_channels, kernel_size=1), | |
| nn.Softmax(dim=2), | |
| ) | |
| self.proj = nn.Linear(input_channels * 2, output_channels) | |
| self.norm = nn.LayerNorm(output_channels) | |
| def forward(self, x): | |
| alpha = self.attn(x) | |
| mean = torch.sum(alpha * x, dim=2) | |
| residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 | |
| std = torch.sqrt(residuals.clamp(min=1e-4, max=1e4)) | |
| x = torch.cat([mean, std], dim=1) | |
| return self.norm(self.proj(x)) | |
class GlobalEncoder(nn.Module):
    """ConvNext backbone followed by pooling into a single global embedding.

    Runs the input through a ConvNextBackbone, then collapses the time axis
    into one (B, output_channels) vector — either with attentive statistics
    pooling or with plain average pooling plus a linear projection.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        skip_embed: bool = False,
        attention_channels: int = 128,
        use_attn_pool: bool = True,
    ):
        super().__init__()
        self.backbone = ConvNextBackbone(
            input_channels=input_channels,
            dim=dim,
            intermediate_dim=intermediate_dim,
            num_layers=num_layers,
            skip_embed=skip_embed,
        )
        if use_attn_pool:
            # Attention-weighted mean/std pooling over time.
            pooling = AttentiveStatsPool(
                input_channels=dim, output_channels=output_channels, attention_channels=attention_channels
            )
        else:
            # Simple mean over time, then project and normalize.
            pooling = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(1),
                nn.Linear(dim, output_channels),
                nn.LayerNorm(output_channels),
            )
        self.pooling = pooling
        self.output_channels = output_channels

    def output_dim(self):
        """Return the size of the embedding produced by forward()."""
        return self.output_channels

    def forward(self, x):
        """Encode x into a (B, output_channels) global embedding."""
        # Backbone emits (B, T, C); pooling layers expect channels-first (B, C, T).
        features = self.backbone(x).transpose(1, 2)
        return self.pooling(features)  # (B, C_out)