| """ |
| ECAPA-TDNN Teacher and Student models for dynamic distillation. |
| |
| Architecture (Desplanques et al., Interspeech 2020): |
| TDNN (input) → 3x SE-Res2NetBlock → Cat+BN → AttStatPool → FC → L2-emb |
| |
| Teacher : 512 channels, emb_dim=192 (~14M params) |
| Student : 256 channels, emb_dim=128 (~3.5M params, same architecture) |
| |
| The student is trained via dynamic knowledge distillation (see distillation.py). |
| Input: log-Mel filterbank (80 bins), shape (B, 80, T). |
| """ |
| from __future__ import annotations |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
class TDNNBlock(nn.Module):
    """Time-delay layer: dilated 1-D convolution -> batch norm -> ReLU.

    Parameters
    ----------
    in_ch    : number of input channels
    out_ch   : number of output channels
    kernel   : convolution kernel size
    dilation : convolution dilation
    groups   : number of convolution groups

    Both sides are padded with ``(kernel - 1) // 2 * dilation`` frames, which
    keeps the time length unchanged for odd kernel sizes.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 5,
                 dilation: int = 1, groups: int = 1):
        super().__init__()
        half_span = (kernel - 1) // 2
        self.conv = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel,
            dilation=dilation,
            padding=half_span * dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm1d(num_features=out_ch)
        self.act = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution, then batch norm, then ReLU."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.act(normalised)
|
|
|
|
class Res2NetConv(nn.Module):
    """Res2Net multi-scale convolution sub-module.

    The input channels are split into ``scale`` equal groups. The first group
    is passed through unchanged; every other group goes through its own
    dilated Conv1d -> BatchNorm1d -> ReLU branch. From the third group on, the
    output of the previous branch is added to the group before its
    convolution. The group outputs are concatenated back along the channel
    axis, so the output has the same number of channels as the input.

    Parameters
    ----------
    channels : total number of channels; must be divisible by ``scale``
    scale    : number of channel groups (>= 1)
    kernel   : kernel size of each branch convolution
    dilation : dilation of each branch convolution

    Raises
    ------
    ValueError
        If ``scale`` is smaller than 1 or ``channels`` is not divisible by
        ``scale``.
    """

    def __init__(self, channels: int, scale: int = 8, kernel: int = 3,
                 dilation: int = 2):
        super().__init__()
        # Explicit checks instead of ``assert``: asserts are removed under
        # ``python -O`` and a bad split would only fail later in forward().
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if channels % scale != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by scale ({scale})"
            )
        self.scale = scale
        w = channels // scale
        pad = (kernel - 1) // 2 * dilation
        # One branch per group except the first, which is an identity path.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(w, w, kernel, dilation=dilation, padding=pad),
                nn.BatchNorm1d(w),
                nn.ReLU(),
            )
            for _ in range(scale - 1)
        ])

    def forward(self, x):
        """Apply the hierarchical multi-scale convolution along the channel axis."""
        chunks = torch.chunk(x, self.scale, dim=1)
        out = [chunks[0]]
        for i, conv in enumerate(self.convs):
            # The first convolved group has no previous branch output to add.
            y = chunks[i + 1] if i == 0 else chunks[i + 1] + out[-1]
            out.append(conv(y))
        return torch.cat(out, dim=1)
|
|
|
|
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The input is averaged over its last axis, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel gate is
    multiplied back onto the input.

    Parameters
    ----------
    channels   : number of input (and output) channels
    bottleneck : hidden size of the gating MLP
    """

    def __init__(self, channels: int, bottleneck: int = 128):
        super().__init__()
        squeeze = nn.Linear(channels, bottleneck)
        excite = nn.Linear(bottleneck, channels)
        self.fc = nn.Sequential(squeeze, nn.ReLU(), excite, nn.Sigmoid())

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned gate."""
        channel_means = torch.mean(x, dim=2)
        gate = self.fc(channel_means)
        # Add a trailing axis so the gate broadcasts over the last dimension.
        return x * gate.unsqueeze(2)
|
|
|
|
class SERes2NetBlock(nn.Module):
    """
    SE-Res2Net block — the core repeating block of ECAPA-TDNN.

    Sequence: 1x1 TDNN -> Res2Net multi-scale conv -> 1x1 TDNN -> SE gate,
    then the block input is added back and the sum goes through batch norm
    and ReLU. Input and output have the same number of channels, which is
    what makes the residual addition possible.

    Parameters
    ----------
    channels      : channel width of the block (input == output)
    scale         : Res2Net scale (number of channel groups)
    se_bottleneck : hidden size of the SE gating MLP
    kernel        : kernel size of the Res2Net convolutions
    dilation      : dilation of the Res2Net convolutions
    """

    def __init__(self, channels: int, scale: int = 8, se_bottleneck: int = 128,
                 kernel: int = 3, dilation: int = 2):
        super().__init__()
        self.tdnn1 = TDNNBlock(channels, channels, 1)
        self.res2net = Res2NetConv(channels, scale, kernel, dilation)
        self.tdnn2 = TDNNBlock(channels, channels, 1)
        self.se = SEBlock(channels, se_bottleneck)
        self.bn = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Return ``relu(bn(branch(x) + x))`` where ``branch`` is the SE-Res2Net path."""
        branch = x
        for stage in (self.tdnn1, self.res2net, self.tdnn2, self.se):
            branch = stage(branch)
        # Residual connection around the whole branch.
        summed = branch + x
        return F.relu(self.bn(summed))
|
|
|
|
class AttentiveStatisticsPooling(nn.Module):
    """
    Computes attention-weighted mean + std across the time axis.

    Each frame is concatenated with the utterance-level (unweighted) mean and
    standard deviation of every channel; a small 1x1-conv network turns that
    context into per-channel attention weights that are softmax-normalised
    over time. The weights are then used to compute a weighted mean and a
    weighted standard deviation per channel.

    Input:  (B, C, T)
    Output: (B, 2C)   — weighted mean followed by weighted std

    Parameters
    ----------
    channels      : number of input channels C
    attention_dim : hidden width of the attention network
    """

    def __init__(self, channels: int, attention_dim: int = 128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels * 3, attention_dim, 1),
            nn.Tanh(),
            nn.Conv1d(attention_dim, channels, 1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):
        """Pool ``x`` of shape (B, C, T) into a (B, 2C) statistics vector."""
        mu = x.mean(dim=2, keepdim=True).expand_as(x)
        if x.size(2) > 1:
            sg = x.std(dim=2, keepdim=True).expand_as(x)
        else:
            # With a single frame the unbiased std has zero degrees of
            # freedom and evaluates to NaN, which would poison the attention
            # weights and the whole output. One frame has no spread, so use 0.
            sg = torch.zeros_like(x)
        ctx = torch.cat([x, mu, sg], dim=1)
        alpha = self.attn(ctx)
        mean = (alpha * x).sum(dim=2)
        # Clamp the weighted variance before the sqrt so the result (and its
        # gradient) stays finite when the variance is zero.
        var = (alpha * (x - mean.unsqueeze(2)).pow(2)).sum(dim=2)
        std = var.clamp(min=1e-9).sqrt()
        return torch.cat([mean, std], dim=1)
|
|
|
|
| |
| |
| |
|
|
class ECAPA_TDNN(nn.Module):
    """
    ECAPA-TDNN speaker encoder with an optional linear classification head.

    Pipeline: input TDNN -> three SE-Res2Net blocks (dilations 2, 3, 4) ->
    channel-wise concatenation of the three block outputs -> 1x1 conv + BN +
    ReLU -> attentive statistics pooling -> BN -> linear -> BN -> L2 norm.

    Parameters
    ----------
    in_channels  : input feature dim (default 80 log-Mel bins)
    channels     : main TDNN channel width (512 = teacher, 256 = student)
    emb_dim      : speaker embedding dimension (192 = teacher, 128 = student)
    n_classes    : number of output classes (for speaker-id softmax head);
                   set 0 to skip the classification head
    scale        : Res2Net scale parameter
    se_bottleneck: SE reduction bottleneck size
    """

    def __init__(self,
                 in_channels: int = 80,
                 channels: int = 512,
                 emb_dim: int = 192,
                 n_classes: int = 0,
                 scale: int = 8,
                 se_bottleneck: int = 128):
        super().__init__()
        self.channels = channels
        self.emb_dim = emb_dim

        # Width after concatenating the three block outputs, and after the
        # pooling layer doubles it (mean + std).
        cat_width = channels * 3
        pooled_width = cat_width * 2

        # Input layer
        self.input_tdnn = TDNNBlock(in_channels, channels, kernel=5)

        # SE-Res2Net stack with growing dilation
        self.block1 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=2)
        self.block2 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=3)
        self.block3 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=4)

        # Multi-layer feature aggregation
        self.cat_bn = nn.BatchNorm1d(cat_width)
        self.cat_tdnn = nn.Conv1d(cat_width, cat_width, 1)

        # Pooling and embedding projection
        self.pool = AttentiveStatisticsPooling(
            cat_width, attention_dim=max(64, channels // 2)
        )
        self.bn_pool = nn.BatchNorm1d(pooled_width)
        self.fc_emb = nn.Linear(pooled_width, emb_dim)
        self.bn_emb = nn.BatchNorm1d(emb_dim)

        # Optional speaker-id head
        if n_classes > 0:
            self.classifier = nn.Linear(emb_dim, n_classes)
        else:
            self.classifier = None

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Conv1d, BatchNorm1d and Linear sub-module.

        Conv1d gets Kaiming-normal weights (fan_out, ReLU), Linear gets
        Xavier-uniform weights, and both get zero biases; BatchNorm1d is
        reset to weight 1 and bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out",
                                        nonlinearity="relu")
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x, return_intermediates: bool = False):
        """
        x : (B, in_channels, T) log-Mel filterbank (80 bins by default)

        Returns
        -------
        emb    : (B, emb_dim) L2-normalised speaker embedding
        logits : (B, n_classes) or None when there is no classification head
        intermediates (only if ``return_intermediates``) : list with the
            outputs of the three SE-Res2Net blocks, for distillation
        """
        feats = self.input_tdnn(x)

        # Run the blocks in sequence, keeping every block output.
        block_outputs = []
        for block in (self.block1, self.block2, self.block3):
            feats = block(feats)
            block_outputs.append(feats)

        merged = torch.cat(block_outputs, dim=1)
        merged = F.relu(self.cat_bn(self.cat_tdnn(merged)))

        stats = self.bn_pool(self.pool(merged))

        emb = F.normalize(self.bn_emb(self.fc_emb(stats)), p=2, dim=1)

        # The head, when present, operates on the normalised embedding.
        logits = None if self.classifier is None else self.classifier(emb)

        if not return_intermediates:
            return emb, logits
        return emb, logits, block_outputs
|
|
|
|
| |
| |
| |
|
|
def build_teacher(n_classes: int = 20) -> ECAPA_TDNN:
    """Build the full-size teacher: 512 channels, 192-dim embedding.

    ``n_classes`` is the size of the classification head (0 disables it).
    """
    teacher_config = {"channels": 512, "emb_dim": 192}
    return ECAPA_TDNN(n_classes=n_classes, **teacher_config)
|
|
|
|
def build_student(n_classes: int = 20) -> ECAPA_TDNN:
    """Build the half-size student: 256 channels, 128-dim embedding.

    Same architecture as the teacher, only narrower. ``n_classes`` is the
    size of the classification head (0 disables it).
    """
    student_config = {"channels": 256, "emb_dim": 128}
    return ECAPA_TDNN(n_classes=n_classes, **student_config)
|
|
|
|
def count_params(model: nn.Module) -> str:
    """Return the number of trainable parameters of ``model`` in millions.

    Only parameters with ``requires_grad`` set are counted. The result is a
    string with two decimals and an ``M`` suffix, e.g. ``"1.00M"``.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return "{:.2f}M".format(trainable / 1e6)
|
|
|
|
| |
| |
| |
|
|
class LogMelFrontend(nn.Module):
    """
    Log-Mel spectrogram front-end built from torchaudio transforms.

    A power Mel spectrogram (``MelSpectrogram`` with ``power=2.0``, Mel band
    from 20 Hz to 7600 Hz) is converted to decibels with ``AmplitudeToDB``
    using ``top_db=80``.

    Parameters
    ----------
    sample_rate : sample rate passed to the Mel transform
    n_fft       : FFT size
    win_length  : analysis window length in samples
    hop_length  : hop between frames in samples
    n_mels      : number of Mel bins
    """

    def __init__(self, sample_rate: int = 16_000,
                 n_fft: int = 512,
                 win_length: int = 400,
                 hop_length: int = 160,
                 n_mels: int = 80):
        super().__init__()
        # Imported here so torchaudio is only needed when the front-end is
        # actually instantiated.
        import torchaudio.transforms as transforms

        mel_settings = dict(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=20,
            f_max=7600,
            power=2.0,
        )
        self.mel = transforms.MelSpectrogram(**mel_settings)
        # NOTE(review): top_db clamps relative to a maximum taken over the
        # input; confirm how torchaudio treats the batch axis of 3-D inputs.
        self.db = transforms.AmplitudeToDB(stype="power", top_db=80)

    def forward(self, waveform):
        """Return the dB-scaled Mel spectrogram of ``waveform``."""
        power_mel = self.mel(waveform)
        return self.db(power_mel)
|
|