ECAPA-QAT / src /ecapa_models.py
KIRILLEVS125's picture
Initial release: ECAPA-QAT W(4/8)A32
2999fe8 verified
"""
ECAPA-TDNN Teacher and Student models for dynamic distillation.
Architecture (Desplanques et al., Interspeech 2020):
TDNN (input) → 3x SE-Res2NetBlock → Cat+BN → AttStatPool → FC → L2-emb
Teacher : 512 channels, emb_dim=192 (~7.0M params)
Student : 256 channels, emb_dim=128 (~2.0M params, same architecture)
The student is trained via dynamic knowledge distillation (see distillation.py).
Input: log-Mel filterbank (80 bins), shape (B, 80, T).
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─────────────────────────────────────────────────────────────────────────────
# Building blocks
# ─────────────────────────────────────────────────────────────────────────────
class TDNNBlock(nn.Module):
    """One TDNN layer: dilated 1-D convolution → BatchNorm → ReLU.

    The padding is ``dilation * ((kernel - 1) // 2)``, which keeps the time
    length unchanged whenever ``kernel`` is odd.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 5,
                 dilation: int = 1, groups: int = 1):
        super().__init__()
        same_pad = dilation * ((kernel - 1) // 2)
        # Attribute names (conv / bn / act) are checkpoint keys — keep them.
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel,
            dilation=dilation,
            padding=same_pad,
            groups=groups,
        )
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU()

    def forward(self, x):  # (B, in_ch, T) → (B, out_ch, T)
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
class Res2NetConv(nn.Module):
    """Res2Net hierarchical multi-scale convolution.

    The channels are split into ``scale`` equal groups. Group 0 is passed
    through unchanged. Group 1 is convolved on its own. Each later group ``k``
    is convolved after adding the convolved output of group ``k - 1``. The
    outputs are concatenated back in the original group order.
    """

    def __init__(self, channels: int, scale: int = 8, kernel: int = 3,
                 dilation: int = 2):
        super().__init__()
        assert channels % scale == 0
        self.scale = scale
        width = channels // scale
        padding = (kernel - 1) // 2 * dilation
        branches = []
        for _ in range(scale - 1):
            branches.append(
                nn.Sequential(
                    nn.Conv1d(width, width, kernel, dilation=dilation,
                              padding=padding),
                    nn.BatchNorm1d(width),
                    nn.ReLU(),
                )
            )
        self.convs = nn.ModuleList(branches)

    def forward(self, x):  # (B, C, T) → (B, C, T)
        groups = torch.chunk(x, self.scale, dim=1)
        pieces = [groups[0]]
        previous = None
        for idx, branch in enumerate(self.convs, start=1):
            branch_in = groups[idx] if previous is None else groups[idx] + previous
            previous = branch(branch_in)
            pieces.append(previous)
        return torch.cat(pieces, dim=1)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation: rescale each channel by a learned gate.

    The gate is computed from the time-averaged channel activations through a
    Linear → ReLU → Linear → Sigmoid bottleneck of width ``bottleneck``.
    """

    def __init__(self, channels: int, bottleneck: int = 128):
        super().__init__()
        layers = [
            nn.Linear(channels, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):  # (B, C, T) → (B, C, T)
        gate = self.fc(x.mean(dim=2))  # (B, C), values in (0, 1)
        return x * gate[:, :, None]
class SERes2NetBlock(nn.Module):
    """
    SE-Res2Net block — the repeated core unit of ECAPA-TDNN.

    Path: 1x1 TDNN → Res2Net multi-scale conv → 1x1 TDNN → SE gating. The
    block input is then added back and the sum goes through BatchNorm + ReLU.
    The channel count is preserved so the residual addition is valid.
    """

    def __init__(self, channels: int, scale: int = 8, se_bottleneck: int = 128,
                 kernel: int = 3, dilation: int = 2):
        super().__init__()
        # Registration order / names define the state_dict layout.
        self.tdnn1 = TDNNBlock(channels, channels, 1)
        self.res2net = Res2NetConv(channels, scale, kernel, dilation)
        self.tdnn2 = TDNNBlock(channels, channels, 1)
        self.se = SEBlock(channels, se_bottleneck)
        self.bn = nn.BatchNorm1d(channels)

    def forward(self, x):  # (B, C, T) → (B, C, T)
        branch = self.se(self.tdnn2(self.res2net(self.tdnn1(x))))
        return F.relu(self.bn(branch + x))
class AttentiveStatisticsPooling(nn.Module):
    """
    Computes attention-weighted mean + std across the time axis.

    Each frame is concatenated with the utterance's channel-wise global mean
    and std to form the attention context. A small 1x1-conv network turns
    that context into per-channel softmax weights over time, which weight
    the pooled mean and standard deviation.

    Input: (B, C, T)
    Output: (B, 2C)   — [weighted mean, weighted std] along dim 1
    """
    def __init__(self, channels: int, attention_dim: int = 128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels * 3, attention_dim, 1),
            nn.Tanh(),
            nn.Conv1d(attention_dim, channels, 1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):  # (B, C, T)
        n_frames = x.size(2)
        # Context vector (global statistics for attention query)
        mu = x.mean(dim=2, keepdim=True).expand_as(x)
        # The unbiased std divides by (T - 1), so it is NaN for a single frame
        # and would turn the entire output into NaN. Use the population std
        # when T == 1; for T > 1 this is exactly the previous computation.
        sg = x.std(dim=2, keepdim=True, unbiased=n_frames > 1).expand_as(x)
        ctx = torch.cat([x, mu, sg], dim=1)  # (B, 3C, T)
        alpha = self.attn(ctx)  # (B, C, T), sums to 1 over T
        mean = (alpha * x).sum(dim=2)  # (B, C)
        var = (alpha * (x - mean.unsqueeze(2)).pow(2)).sum(dim=2)
        # Clamp before sqrt so the gradient stays finite for ~zero variance.
        std = var.clamp(1e-9).sqrt()
        return torch.cat([mean, std], dim=1)  # (B, 2C)
# ─────────────────────────────────────────────────────────────────────────────
# Full ECAPA-TDNN
# ─────────────────────────────────────────────────────────────────────────────
class ECAPA_TDNN(nn.Module):
    """
    ECAPA-TDNN speaker-embedding network.

    Pipeline: input TDNN (kernel 5) → three SE-Res2Net blocks with dilations
    2, 3, 4 → concatenation of the three block outputs → 1x1 conv + BN + ReLU
    → attentive statistics pooling → BN → linear projection → BN → L2
    normalisation, plus an optional linear classifier on the normalised
    embedding.

    Parameters
    ----------
    in_channels : input feature dim (default 80 log-Mel bins)
    channels : main TDNN channel width (512 = teacher, 256 = student)
    emb_dim : speaker embedding dimension (192 = teacher, 128 = student)
    n_classes : number of output classes for the speaker-id head;
                a value <= 0 skips the classification head
    scale : Res2Net scale parameter
    se_bottleneck: SE reduction bottleneck size
    """

    def __init__(self,
                 in_channels: int = 80,
                 channels: int = 512,
                 emb_dim: int = 192,
                 n_classes: int = 0,
                 scale: int = 8,
                 se_bottleneck: int = 128):
        super().__init__()
        self.channels = channels
        self.emb_dim = emb_dim
        # NOTE: the registration order below fixes the state_dict layout and
        # the order in which _init_weights draws random numbers — keep it.
        cat_ch = 3 * channels
        self.input_tdnn = TDNNBlock(in_channels, channels, kernel=5)
        self.block1 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=2)
        self.block2 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=3)
        self.block3 = SERes2NetBlock(channels, scale, se_bottleneck, dilation=4)
        self.cat_bn = nn.BatchNorm1d(cat_ch)
        self.cat_tdnn = nn.Conv1d(cat_ch, cat_ch, 1)
        self.pool = AttentiveStatisticsPooling(
            cat_ch, attention_dim=max(64, channels // 2))
        self.bn_pool = nn.BatchNorm1d(2 * cat_ch)   # pooled mean + std
        self.fc_emb = nn.Linear(2 * cat_ch, emb_dim)
        self.bn_emb = nn.BatchNorm1d(emb_dim)
        self.classifier = nn.Linear(emb_dim, n_classes) if n_classes > 0 else None
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal convs, Xavier-uniform linears, unit/zero BatchNorms."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out",
                                        nonlinearity="relu")
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x, return_intermediates: bool = False):
        """
        x : (B, in_channels, T) features (log-Mel filterbank by default)

        Returns
        -------
        emb : (B, emb_dim) L2-normalised speaker embedding
        logits : (B, n_classes) or None when the head is disabled
        intermediates (only if return_intermediates) : [h1, h2, h3], the
            three SE-Res2Net block outputs, each (B, channels, T)
        """
        h = self.input_tdnn(x)
        block_outputs = []
        for block in (self.block1, self.block2, self.block3):
            h = block(h)
            block_outputs.append(h)

        fused = self.cat_tdnn(torch.cat(block_outputs, dim=1))  # (B, 3C, T)
        fused = F.relu(self.cat_bn(fused))

        stats = self.bn_pool(self.pool(fused))  # (B, 6C)
        emb = F.normalize(self.bn_emb(self.fc_emb(stats)), p=2, dim=1)

        logits = None if self.classifier is None else self.classifier(emb)

        if return_intermediates:
            return emb, logits, block_outputs
        return emb, logits
# ─────────────────────────────────────────────────────────────────────────────
# Factory helpers
# ─────────────────────────────────────────────────────────────────────────────
def build_teacher(n_classes: int = 20) -> ECAPA_TDNN:
    """Construct the full-width teacher model (512 channels, 192-dim embedding).

    ``n_classes`` sets the size of the speaker-id head (0 disables it).
    """
    teacher_cfg = dict(channels=512, emb_dim=192)
    return ECAPA_TDNN(n_classes=n_classes, **teacher_cfg)
def build_student(n_classes: int = 20) -> ECAPA_TDNN:
    """Construct the half-width student model (256 channels, 128-dim embedding).

    Same architecture as the teacher; ``n_classes`` sets the size of the
    speaker-id head (0 disables it).
    """
    student_cfg = dict(channels=256, emb_dim=128)
    return ECAPA_TDNN(n_classes=n_classes, **student_cfg)
def count_params(model: nn.Module) -> str:
    """Return the trainable-parameter count of *model* in millions.

    Only parameters with ``requires_grad=True`` are counted. The result is
    formatted with two decimals and an ``M`` suffix, e.g. ``"1.50M"``.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    millions = trainable / 1e6
    return f"{millions:.2f}M"
# ─────────────────────────────────────────────────────────────────────────────
# Log-Mel feature extraction (for training)
# ─────────────────────────────────────────────────────────────────────────────
class LogMelFrontend(nn.Module):
"""
Differentiable log-Mel spectrogram on GPU.
Uses torchaudio.transforms.
"""
def __init__(self, sample_rate: int = 16_000,
n_fft: int = 512,
win_length: int = 400,
hop_length: int = 160,
n_mels: int = 80):
super().__init__()
import torchaudio.transforms as T
self.mel = T.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mels,
f_min=20, f_max=7600,
power=2.0,
)
self.db = T.AmplitudeToDB(stype="power", top_db=80)
def forward(self, waveform): # (B, T) → (B, 80, T')
mel = self.mel(waveform)
return self.db(mel)