""" ECAPA-TDNN Teacher and Student models for dynamic distillation. Architecture (Desplanques et al., Interspeech 2020): TDNN (input) → 3x SE-Res2NetBlock → Cat+BN → AttStatPool → FC → L2-emb Teacher : 512 channels, emb_dim=192 (~14M params) Student : 256 channels, emb_dim=128 (~3.5M params, same architecture) The student is trained via dynamic knowledge distillation (see distillation.py). Input: log-Mel filterbank (80 bins), shape (B, 80, T). """ from __future__ import annotations import math import torch import torch.nn as nn import torch.nn.functional as F # ───────────────────────────────────────────────────────────────────────────── # Building blocks # ───────────────────────────────────────────────────────────────────────────── class TDNNBlock(nn.Module): """Standard Time-Delay Neural Network layer.""" def __init__(self, in_ch: int, out_ch: int, kernel: int = 5, dilation: int = 1, groups: int = 1): super().__init__() pad = (kernel - 1) // 2 * dilation self.conv = nn.Conv1d(in_ch, out_ch, kernel, dilation=dilation, padding=pad, groups=groups) self.bn = nn.BatchNorm1d(out_ch) self.act = nn.ReLU() def forward(self, x): # (B, C, T) → (B, out_ch, T) return self.act(self.bn(self.conv(x))) class Res2NetConv(nn.Module): """Res2Net multi-scale convolution sub-module.""" def __init__(self, channels: int, scale: int = 8, kernel: int = 3, dilation: int = 2): super().__init__() assert channels % scale == 0 self.scale = scale w = channels // scale pad = (kernel - 1) // 2 * dilation self.convs = nn.ModuleList([ nn.Sequential( nn.Conv1d(w, w, kernel, dilation=dilation, padding=pad), nn.BatchNorm1d(w), nn.ReLU(), ) for _ in range(scale - 1) ]) def forward(self, x): # (B, C, T) chunks = torch.chunk(x, self.scale, dim=1) out = [chunks[0]] for i, conv in enumerate(self.convs): y = chunks[i + 1] if i == 0 else chunks[i + 1] + out[-1] out.append(conv(y)) return torch.cat(out, dim=1) class SEBlock(nn.Module): """Squeeze-and-Excitation channel attention.""" def __init__(self, channels: int, bottleneck: int = 128): super().__init__() self.fc = nn.Sequential( nn.Linear(channels, bottleneck), nn.ReLU(), nn.Linear(bottleneck, channels), nn.Sigmoid(), ) def forward(self, x): # (B, C, T) s = x.mean(dim=2) # (B, C) global avg s = self.fc(s).unsqueeze(2) # (B, C, 1) return x * s class SERes2NetBlock(nn.Module): """ SE-Res2Net block — the core repeating block of ECAPA-TDNN. in_ch = out_ch (residual connection) """ def __init__(self, channels: int, scale: int = 8, se_bottleneck: int = 128, kernel: int = 3, dilation: int = 2): super().__init__() self.tdnn1 = TDNNBlock(channels, channels, 1) self.res2net = Res2NetConv(channels, scale, kernel, dilation) self.tdnn2 = TDNNBlock(channels, channels, 1) self.se = SEBlock(channels, se_bottleneck) self.bn = nn.BatchNorm1d(channels) def forward(self, x): # (B, C, T) residual = x x = self.tdnn1(x) x = self.res2net(x) x = self.tdnn2(x) x = self.se(x) return F.relu(self.bn(x + residual)) class AttentiveStatisticsPooling(nn.Module): """ Computes attention-weighted mean + std across the time axis. Input: (B, C, T) Output: (B, 2C) """ def __init__(self, channels: int, attention_dim: int = 128): super().__init__() self.attn = nn.Sequential( nn.Conv1d(channels * 3, attention_dim, 1), nn.Tanh(), nn.Conv1d(attention_dim, channels, 1), nn.Softmax(dim=2), ) def forward(self, x): # (B, C, T) # Context vector (global statistics for attention query) mu = x.mean(dim=2, keepdim=True).expand_as(x) sg = x.std(dim=2, keepdim=True).expand_as(x) ctx = torch.cat([x, mu, sg], dim=1) # (B, 3C, T) alpha = self.attn(ctx) # (B, C, T) mean = (alpha * x).sum(dim=2) # (B, C) std = (alpha * (x - mean.unsqueeze(2)).pow(2)).sum(dim=2).clamp(1e-9).sqrt() return torch.cat([mean, std], dim=1) # (B, 2C) # ───────────────────────────────────────────────────────────────────────────── # Full ECAPA-TDNN # ───────────────────────────────────────────────────────────────────────────── class ECAPA_TDNN(nn.Module): """ Parameters ---------- in_channels : input feature dim (default 80 log-Mel bins) channels : main TDNN channel width (512 = teacher, 256 = student) emb_dim : speaker embedding dimension (192 = teacher, 128 = student) n_classes : number of output classes (for speaker-id softmax head); set 0 to skip the classification head scale : Res2Net scale parameter se_bottleneck: SE reduction bottleneck size """ def __init__(self, in_channels: int = 80, channels: int = 512, emb_dim: int = 192, n_classes: int = 0, scale: int = 8, se_bottleneck: int = 128): super().__init__() self.channels = channels self.emb_dim = emb_dim # Input TDNN self.input_tdnn = TDNNBlock(in_channels, channels, kernel=5) # Three SE-Res2Net blocks with increasing dilation self.block1 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=2) self.block2 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=3) self.block3 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=4) # Aggregation: cat all three block outputs → channel * 3 self.cat_bn = nn.BatchNorm1d(channels * 3) self.cat_tdnn = nn.Conv1d(channels * 3, channels * 3, 1) # Attentive statistics pooling → 2 * channel * 3 → emb self.pool = AttentiveStatisticsPooling(channels * 3, attention_dim=max(64, channels // 2)) self.bn_pool = nn.BatchNorm1d(channels * 6) self.fc_emb = nn.Linear(channels * 6, emb_dim) self.bn_emb = nn.BatchNorm1d(emb_dim) # Optional classification head (for speaker-id training) self.classifier = nn.Linear(emb_dim, n_classes) if n_classes > 0 else None self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm1d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x, return_intermediates: bool = False): """ x : (B, 80, T) log-Mel filterbank Returns ------- emb : (B, emb_dim) L2-normalised speaker embedding logits : (B, n_classes) or None intermediates (optional) : list of block output tensors for distillation """ x = self.input_tdnn(x) h1 = self.block1(x) h2 = self.block2(h1) h3 = self.block3(h2) cat = torch.cat([h1, h2, h3], dim=1) # (B, 3C, T) cat = F.relu(self.cat_bn(self.cat_tdnn(cat))) pooled = self.pool(cat) # (B, 6C) pooled = self.bn_pool(pooled) emb = self.bn_emb(self.fc_emb(pooled)) # (B, emb_dim) emb = F.normalize(emb, p=2, dim=1) logits = self.classifier(emb) if self.classifier is not None else None if return_intermediates: return emb, logits, [h1, h2, h3] return emb, logits # ───────────────────────────────────────────────────────────────────────────── # Factory helpers # ───────────────────────────────────────────────────────────────────────────── def build_teacher(n_classes: int = 20) -> ECAPA_TDNN: """Full-size teacher (512 ch, 192-dim emb).""" return ECAPA_TDNN(channels=512, emb_dim=192, n_classes=n_classes) def build_student(n_classes: int = 20) -> ECAPA_TDNN: """Half-size student (256 ch, 128-dim emb) — same architecture.""" return ECAPA_TDNN(channels=256, emb_dim=128, n_classes=n_classes) def count_params(model: nn.Module) -> str: n = sum(p.numel() for p in model.parameters() if p.requires_grad) return f"{n/1e6:.2f}M" # ───────────────────────────────────────────────────────────────────────────── # Log-Mel feature extraction (for training) # ───────────────────────────────────────────────────────────────────────────── class LogMelFrontend(nn.Module): """ Differentiable log-Mel spectrogram on GPU. Uses torchaudio.transforms. """ def __init__(self, sample_rate: int = 16_000, n_fft: int = 512, win_length: int = 400, hop_length: int = 160, n_mels: int = 80): super().__init__() import torchaudio.transforms as T self.mel = T.MelSpectrogram( sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length, n_mels=n_mels, f_min=20, f_max=7600, power=2.0, ) self.db = T.AmplitudeToDB(stype="power", top_db=80) def forward(self, waveform): # (B, T) → (B, 80, T') mel = self.mel(waveform) return self.db(mel)