# sunf / flow_matching / speaker_encoder.py
# Uploaded by anhtunguyen98: "Upload folder using huggingface_hub"
# Revision 4698bfc (verified)
"""
CAM++ speaker encoder.
Architecture from: https://github.com/Plachtaa/seed-vc/blob/main/modules/campplus/
Loads pretrained/campplus_cn_common.bin directly.
"""
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torchaudio
# ── Mel-filterbank front-end ──────────────────────────────────────────────────
class FBankExtractor(nn.Module):
    """Log Mel-filterbank front-end with per-utterance mean normalisation.

    Wraps a ``torchaudio`` ``MelSpectrogram`` configured for 16 kHz audio:
    512-point FFT, 400-sample (25 ms) Hamming window, 160-sample (10 ms) hop,
    80 HTK-scale Mel bins between 20 Hz and 7.6 kHz, no filter normalisation.
    """

    def __init__(self):
        super().__init__()
        mel_config = dict(
            sample_rate=16000,
            n_fft=512,
            hop_length=160,
            win_length=400,
            n_mels=80,
            f_min=20.0,
            f_max=7600.0,
            window_fn=torch.hamming_window,
            norm=None,
            mel_scale="htk",
        )
        self.fbank = torchaudio.transforms.MelSpectrogram(**mel_config)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Convert waveforms to mean-normalised log-Mel features.

        Args:
            wav: ``(B, T)`` batch of waveforms, or a single ``(T,)`` waveform
                which is treated as a batch of one.

        Returns:
            Tensor of shape ``(B, T_frames, 80)``.
        """
        batch = wav.unsqueeze(0) if wav.dim() == 1 else wav
        # Floor the energies before the log so silent bins stay finite.
        log_mel = self.fbank(batch).clamp(min=1e-6).log()
        # Subtract each Mel bin's mean over time (last axis at this point).
        centred = log_mel - log_mel.mean(dim=-1, keepdim=True)
        return centred.transpose(1, 2)  # (B, T_frames, 80)
# ── Building blocks ───────────────────────────────────────────────────────────
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` from a dash-separated layer spec.

    Recognised tokens are ``relu``, ``prelu``, ``batchnorm`` and
    ``batchnorm_`` (batch norm without affine parameters). Both batch-norm
    variants are registered under the sub-module name ``batchnorm``.

    Args:
        config_str: Spec such as ``'batchnorm-relu'``.
        channels: Channel count handed to the PReLU / BatchNorm1d layers.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        ValueError: If a token is not one of the recognised names.
    """
    # token -> (registered sub-module name, lazy constructor)
    builders = {
        'relu': ('relu', lambda: nn.ReLU(inplace=True)),
        'prelu': ('prelu', lambda: nn.PReLU(channels)),
        'batchnorm': ('batchnorm', lambda: nn.BatchNorm1d(channels)),
        'batchnorm_': ('batchnorm',
                       lambda: nn.BatchNorm1d(channels, affine=False)),
    }
    stack = nn.Sequential()
    for token in config_str.split('-'):
        if token not in builders:
            raise ValueError(f'Unexpected module ({token}).')
        module_name, build = builders[token]
        stack.add_module(module_name, build())
    return stack
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True):
    """Pool ``x`` along ``dim`` into concatenated mean and standard deviation.

    Args:
        x: Input tensor.
        dim: Axis reduced by the mean / std.
        keepdim: If true, a singleton axis is re-inserted at ``dim``.
        unbiased: Forwarded to ``torch.std``.

    Returns:
        ``[mean, std]`` joined along the last axis of the reduced tensors.
    """
    pooled = torch.cat(
        (torch.mean(x, dim=dim), torch.std(x, dim=dim, unbiased=unbiased)),
        dim=-1,
    )
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
def masked_statistics_pooling(x, x_lens, dim=-1, keepdim=False, unbiased=True):
    """Length-aware mean/std pooling, computed sample by sample.

    For sample ``i`` only the first ``x_lens[i]`` positions of the last axis
    of ``x[i]`` contribute to the statistics.

    Args:
        x: Batched tensor indexed as ``x[i, :, :length]``.
        x_lens: One valid length per batch element.
        dim: Axis of each per-sample slice that is reduced.
        keepdim: If true, a singleton axis is inserted at ``dim`` after
            stacking.
        unbiased: Forwarded to ``Tensor.std``.

    Returns:
        Per-sample ``[mean, std]`` vectors stacked along a new leading axis.
    """

    def _pool_one(sample, length):
        # Drop the padded tail before computing statistics.
        valid = sample[:, :length]
        return torch.cat(
            (valid.mean(dim=dim), valid.std(dim=dim, unbiased=unbiased)),
            dim=-1,
        )

    pooled = torch.stack(
        [_pool_one(x[idx], length) for idx, length in enumerate(x_lens)],
        dim=0,
    )
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
class StatsPool(nn.Module):
    """Module wrapper around the mean/std pooling helpers."""

    def forward(self, x, x_lens=None):
        """Pool ``x``; use the masked variant when ``x_lens`` is supplied."""
        if x_lens is None:
            return statistics_pooling(x)
        return masked_statistics_pooling(x, x_lens)
class BasicResBlock(nn.Module):
    """Two-convolution 2D residual block.

    The stride is applied only along the first spatial axis (``(stride, 1)``),
    so the second spatial axis keeps its length. When the stride or channel
    count changes, the skip path is a strided 1x1 convolution followed by
    batch norm; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        axis_stride = (stride, 1)

        self.conv1 = nn.Conv2d(in_planes, planes, 3,
                               stride=axis_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: the skip connection is the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, 1,
                          stride=axis_stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        branch += self.shortcut(x)
        return F.relu(branch)
class FCM(nn.Module):
    """2D ResNet front-end: (B, 80, T) → (B, 320, T) with the defaults.

    A stem convolution is followed by two residual stages and a final
    convolution. Each of the three later stages uses stride 2 along the
    feature axis only, so that axis shrinks by 8 while time is untouched.
    The channel and feature axes are then merged into one.
    """

    def __init__(self, block=BasicResBlock, num_blocks=(2, 2),
                 m_channels=32, feat_dim=80):
        super().__init__()
        # Running input width, advanced by _make_layer as stages are built.
        self.in_planes = m_channels

        self.conv1 = nn.Conv2d(1, m_channels, 3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        first_depth, second_depth = num_blocks[0], num_blocks[1]
        self.layer1 = self._make_layer(block, m_channels, first_depth, stride=2)
        self.layer2 = self._make_layer(block, m_channels, second_depth, stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, 3,
                               stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)

        # Three stride-2 stages: feat_dim // 8 bins remain (32 * 10 = 320).
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one is strided."""
        stage = []
        for block_stride in (stride, *([1] * (num_blocks - 1))):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map ``(B, feat, T)`` features to ``(B, C * feat', T)``."""
        maps = x.unsqueeze(1)  # (B, 1, 80, T)
        maps = F.relu(self.bn1(self.conv1(maps)))
        maps = self.layer2(self.layer1(maps))
        maps = F.relu(self.bn2(self.conv2(maps)))
        batch, chans, bins, frames = maps.shape
        return maps.reshape(batch, chans * bins, frames)  # (B, 320, T)
class TDNNLayer(nn.Module):
    """1D convolution followed by the non-linearity stack from ``config_str``.

    A negative ``padding`` requests automatic "same"-style padding,
    ``(kernel_size - 1) // 2 * dilation``, which requires an odd kernel.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        auto_pad = padding < 0
        if auto_pad:
            assert kernel_size % 2 == 1
            padding = (kernel_size - 1) // 2 * dilation
        conv_kwargs = dict(stride=stride, padding=padding,
                           dilation=dilation, bias=bias)
        self.linear = nn.Conv1d(in_channels, out_channels, kernel_size,
                                **conv_kwargs)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        """Apply the convolution, then the non-linearity stack."""
        projected = self.linear(x)
        return self.nonlinear(projected)
class CAMLayer(nn.Module):
    """Local convolution gated by a context-derived sigmoid mask.

    The mask is computed from the sum of the global time-average of the input
    and a segment-wise pooled copy of it, passed through a two-layer 1x1
    bottleneck (``bn_channels // reduction`` hidden channels).
    """

    def __init__(self, bn_channels, out_channels, kernel_size,
                 stride, padding, dilation, bias, reduction=2):
        super().__init__()
        squeezed = bn_channels // reduction
        self.linear_local = nn.Conv1d(
            bn_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias,
        )
        self.linear1 = nn.Conv1d(bn_channels, squeezed, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(squeezed, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool ``x`` over consecutive segments and broadcast back in time.

        Each block of ``seg_len`` frames is replaced by its average
        (``stype == 'avg'``) or its maximum (any other ``stype``); the result
        has the same length as ``x`` along the last axis.
        """
        pool = F.avg_pool1d if stype == 'avg' else F.max_pool1d
        # ceil_mode keeps the trailing partial segment instead of dropping it.
        pooled = pool(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        # Repeat every segment value seg_len times, then trim the overshoot.
        stretched = pooled.repeat_interleave(seg_len, dim=-1)
        return stretched[..., :x.shape[-1]]

    def forward(self, x):
        """Return ``linear_local(x)`` scaled by the context mask."""
        local = self.linear_local(x)
        global_ctx = x.mean(-1, keepdim=True)
        context = global_ctx + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(context))))
        return local * gate
class CAMDenseTDNNLayer(nn.Module):
    """Bottleneck (non-linearity + 1x1 conv) feeding a ``CAMLayer``.

    With ``memory_efficient`` enabled, the bottleneck is run through
    ``torch.utils.checkpoint`` while the module is in training mode.
    """

    def __init__(self, in_channels, out_channels, bn_channels,
                 kernel_size, stride=1, dilation=1, bias=False,
                 config_str='batchnorm-relu', memory_efficient=False):
        super().__init__()
        assert kernel_size % 2 == 1
        # Padding that keeps the time length for an odd, dilated kernel.
        same_padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(
            bn_channels, out_channels, kernel_size,
            stride=stride, padding=same_padding,
            dilation=dilation, bias=bias,
        )

    def bn_function(self, x):
        """Bottleneck: non-linearity stack, then the 1x1 projection."""
        activated = self.nonlinear1(x)
        return self.linear1(activated)

    def forward(self, x):
        """Run the bottleneck (optionally checkpointed), then the CAM layer."""
        use_checkpoint = self.training and self.memory_efficient
        bottleneck = (
            cp.checkpoint(self.bn_function, x)
            if use_checkpoint
            else self.bn_function(x)
        )
        return self.cam_layer(self.nonlinear2(bottleneck))
class CAMDenseTDNNBlock(nn.ModuleList):
    """Densely connected stack of ``CAMDenseTDNNLayer`` modules.

    Layer ``i`` (0-based) receives ``in_channels + i * out_channels`` input
    channels because every layer's output is concatenated onto its input.
    Layers are registered as ``tdnnd1``, ``tdnnd2``, ... to match checkpoint
    keys.
    """

    def __init__(self, num_layers, in_channels, out_channels, bn_channels,
                 kernel_size, stride=1, dilation=1, bias=False,
                 config_str='batchnorm-relu', memory_efficient=False):
        super().__init__()
        for idx in range(num_layers):
            # Register under the checkpoint-compatible name rather than the
            # numeric index ModuleList would assign.
            self.add_module(
                f"tdnnd{idx + 1}",
                CAMDenseTDNNLayer(
                    in_channels=in_channels + idx * out_channels,
                    out_channels=out_channels,
                    bn_channels=bn_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=bias,
                    config_str=config_str,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, x):
        """Grow ``x`` along the channel axis by one layer output per layer."""
        features = x
        for layer in self:
            features = torch.cat((features, layer(features)), dim=1)
        return features
class TransitLayer(nn.Module):
    """Non-linearity stack followed by a 1x1 convolution."""

    def __init__(self, in_channels, out_channels, bias=True,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels,
                                kernel_size=1, bias=bias)

    def forward(self, x):
        """Activate first, then project to ``out_channels``."""
        activated = self.nonlinear(x)
        return self.linear(activated)
class DenseLayer(nn.Module):
    """1x1 convolution followed by a non-linearity stack.

    Accepts 3D input directly; 2D input is given a temporary trailing axis
    for the convolution, which is removed again before the non-linearity.
    """

    def __init__(self, in_channels, out_channels, bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.linear = nn.Conv1d(in_channels, out_channels,
                                kernel_size=1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        """Project ``x`` and apply the non-linearity stack."""
        is_vector_batch = x.dim() == 2
        if is_vector_batch:
            x = x.unsqueeze(-1)  # Conv1d needs a length axis
        x = self.linear(x)
        if is_vector_batch:
            x = x.squeeze(-1)
        return self.nonlinear(x)
# ── CAM++ model ───────────────────────────────────────────────────────────────
class CAMPPlus(nn.Module):
    """CAM++ speaker embedding network.

    Pipeline: ``FCM`` 2D front-end → stride-2 TDNN → three dense CAM blocks
    (12, 24 and 16 layers), each followed by a channel-halving transit layer
    → mean/std pooling → dense embedding layer.

    Args:
        feat_dim: Number of input feature bins per frame.
        embedding_size: Size of the output embedding.
        growth_rate: Channels added by each dense layer.
        bn_size: Bottleneck width multiplier (``bn_size * growth_rate``).
        init_channels: Output channels of the first TDNN layer.
        config_str: Non-linearity spec passed to the sub-layers.
        memory_efficient: Enables checkpointing inside the dense layers.
    """

    def __init__(self, feat_dim=80, embedding_size=192,
                 growth_rate=32, bn_size=4, init_channels=128,
                 config_str='batchnorm-relu', memory_efficient=True):
        super().__init__()
        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels  # 320
        self.xvector = nn.Sequential(OrderedDict([
            ('tdnn', TDNNLayer(channels, init_channels, 5,
                               stride=2, dilation=1, padding=-1,
                               config_str=config_str)),
        ]))
        channels = init_channels  # 128
        # With the defaults the widths go 128 -> 512 -> 256 -> 1024 -> 512
        # -> 1024 -> 512.
        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module(f'block{i+1}', block)
            channels += num_layers * growth_rate
            self.xvector.add_module(
                f'transit{i+1}',
                TransitLayer(channels, channels // 2, bias=False,
                             config_str=config_str),
            )
            channels //= 2
        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))
        self.stats = StatsPool()
        # Pooling concatenates mean and std, hence twice the channel count.
        self.dense = DenseLayer(channels * 2, embedding_size,
                                config_str='batchnorm_')
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, remapping legacy ``xvector.stats/dense`` keys.

        Old checkpoints stored the pooling and dense layers inside
        ``xvector``; their keys are rewritten to the top-level ``stats`` /
        ``dense`` names used here.

        Args:
            state_dict: Mapping from parameter names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``.
            **kwargs: Any further keyword arguments accepted by the installed
                ``nn.Module.load_state_dict``.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns (the report of
            missing and unexpected keys), so callers can inspect it.
        """
        legacy_prefixes = ('xvector.stats', 'xvector.dense')
        strip = len('xvector.')
        new_sd = {}
        for k, v in state_dict.items():
            if k.startswith(legacy_prefixes):
                # Only the leading 'xvector.' is removed.
                k = k[strip:]
            new_sd[k] = v
        return super().load_state_dict(new_sd, strict, **kwargs)

    def forward(self, x, x_lens=None):
        """Embed a batch of feature sequences.

        Args:
            x: ``(B, T, feat_dim)`` features.
            x_lens: Optional per-sample lengths used for masked pooling.
                NOTE(review): these slice the time axis *after* the stride-2
                TDNN, so they presumably need to be expressed in
                post-downsampling frames — confirm against callers.

        Returns:
            ``(B, embedding_size)`` embeddings (not normalised).
        """
        x = x.permute(0, 2, 1)     # (B, T, 80) → (B, 80, T)
        x = self.head(x)           # (B, 320, T)
        x = self.xvector(x)        # (B, 512, T'), T' shortened by tdnn stride 2
        x = self.stats(x, x_lens)  # (B, 1024)
        x = self.dense(x)          # (B, 192)
        return x
# ── High-level speaker encoder ────────────────────────────────────────────────
class SpeakerEncoder(nn.Module):
    """Waveform → L2-normalised CAM++ speaker embedding."""

    def __init__(self, ckpt_path: str, device: str = "cuda"):
        """Build the front-end and CAM++ model, optionally loading weights.

        Args:
            ckpt_path: Checkpoint file; a falsy value skips loading.
            device: ``map_location`` used while reading the checkpoint. The
                module itself is not moved to this device here.
        """
        super().__init__()
        self.fbank = FBankExtractor()
        self.campplus = CAMPPlus()
        if ckpt_path:
            self._load_checkpoint(ckpt_path, device)

    def _load_checkpoint(self, path: str, device: str):
        """Read ``path``, load it into ``campplus`` and switch it to eval.

        The state dict is taken from the ``"model"`` entry, else the
        ``"state_dict"`` entry, else the loaded object itself. Loading is
        non-strict; key mismatches are reported as a warning instead of
        passing silently.
        """
        import warnings

        map_location = device
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            # The tensors are copied into the existing parameters anyway, so
            # reading them on CPU gives the same result without crashing.
            warnings.warn(
                f"CUDA is not available; loading {path} on CPU instead of "
                f"{device}."
            )
            map_location = "cpu"

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=map_location)
        sd = ckpt.get("model", ckpt.get("state_dict", ckpt))
        result = self.campplus.load_state_dict(sd, strict=False)

        # ``result`` may be None if load_state_dict does not return a report.
        missing = list(getattr(result, "missing_keys", ()))
        unexpected = list(getattr(result, "unexpected_keys", ()))
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint {path} does not fully match the model: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}."
            )

        self.campplus.eval()
        print(f"[SpeakerEncoder] Loaded {path}")

    @torch.no_grad()
    def extract_embedding(self, wav: torch.Tensor, sr: int = 16000) -> torch.Tensor:
        """Embed a single utterance.

        Args:
            wav: ``(T,)`` waveform, or 2D input that is averaged over its
                first axis.
            sr: Sample rate of ``wav``; resampled to 16 kHz when different.

        Returns:
            ``(192,)`` L2-normalised embedding (with the default model size).
        """
        if wav.dim() == 2:
            wav = wav.mean(0)
        device = next(self.campplus.parameters()).device
        wav = wav.to(device)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        feats = self.fbank(wav.unsqueeze(0))  # (1, T_frames, 80)

        # Batch norm cannot compute batch statistics from this single sample,
        # so run in eval mode and restore the previous mode afterwards.
        was_training = self.campplus.training
        if was_training:
            self.campplus.eval()
        try:
            emb = self.campplus(feats).squeeze(0)  # (192,)
        finally:
            if was_training:
                self.campplus.train()
        return F.normalize(emb, dim=-1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """feats: (B, T, 80) → (B, 192) L2-normalised"""
        return F.normalize(self.campplus(feats), dim=-1)
# ── Speaker projection (for flow matching model) ──────────────────────────────
class SpeakerProjection(nn.Module):
    """Two-layer MLP (Linear → SiLU → Linear) applied to speaker embeddings.

    Args:
        spk_emb_dim: Size of the incoming speaker embedding.
        hidden_dim: Width of both the hidden and the output layer.
    """

    def __init__(self, spk_emb_dim: int = 192, hidden_dim: int = 512):
        super().__init__()
        expand = nn.Linear(spk_emb_dim, hidden_dim)
        activation = nn.SiLU()
        refine = nn.Linear(hidden_dim, hidden_dim)
        self.proj = nn.Sequential(expand, activation, refine)

    def forward(self, spk_emb: torch.Tensor) -> torch.Tensor:
        """Project ``spk_emb`` from ``spk_emb_dim`` to ``hidden_dim``."""
        projected = self.proj(spk_emb)
        return projected