""" CAM++ speaker encoder. Architecture from: https://github.com/Plachtaa/seed-vc/blob/main/modules/campplus/ Loads pretrained/campplus_cn_common.bin directly. """ from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp import torchaudio # ── Mel-filterbank front-end ────────────────────────────────────────────────── class FBankExtractor(nn.Module): """80-dim log Mel-filterbank at 16 kHz (HTK scale, 25 ms / 10 ms).""" def __init__(self): super().__init__() self.fbank = torchaudio.transforms.MelSpectrogram( sample_rate=16000, n_fft=512, hop_length=160, win_length=400, n_mels=80, f_min=20.0, f_max=7600.0, window_fn=torch.hamming_window, norm=None, mel_scale="htk", ) def forward(self, wav: torch.Tensor) -> torch.Tensor: """wav: (B, T) or (T,) → (B, T_frames, 80)""" if wav.dim() == 1: wav = wav.unsqueeze(0) feats = self.fbank(wav) feats = torch.log(feats.clamp(min=1e-6)) feats = feats - feats.mean(dim=-1, keepdim=True) return feats.transpose(1, 2) # (B, T_frames, 80) # ── Building blocks ─────────────────────────────────────────────────────────── def get_nonlinear(config_str, channels): nonlinear = nn.Sequential() for name in config_str.split('-'): if name == 'relu': nonlinear.add_module('relu', nn.ReLU(inplace=True)) elif name == 'prelu': nonlinear.add_module('prelu', nn.PReLU(channels)) elif name == 'batchnorm': nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels)) elif name == 'batchnorm_': nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels, affine=False)) else: raise ValueError(f'Unexpected module ({name}).') return nonlinear def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True): mean = x.mean(dim=dim) std = x.std(dim=dim, unbiased=unbiased) stats = torch.cat([mean, std], dim=-1) if keepdim: stats = stats.unsqueeze(dim=dim) return stats def masked_statistics_pooling(x, x_lens, dim=-1, keepdim=False, unbiased=True): stats = [] for i, x_len in enumerate(x_lens): xi = x[i, :, :x_len] mean = xi.mean(dim=dim) std = xi.std(dim=dim, unbiased=unbiased) stats.append(torch.cat([mean, std], dim=-1)) stats = torch.stack(stats, dim=0) if keepdim: stats = stats.unsqueeze(dim=dim) return stats class StatsPool(nn.Module): def forward(self, x, x_lens=None): if x_lens is not None: return masked_statistics_pooling(x, x_lens) return statistics_pooling(x) class BasicResBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=(stride, 1), padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, self.expansion * planes, 1, stride=(stride, 1), bias=False), nn.BatchNorm2d(self.expansion * planes), ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) return F.relu(out) class FCM(nn.Module): """2D ResNet front-end: (B, 1, 80, T) → (B, 320, T)""" def __init__(self, block=BasicResBlock, num_blocks=(2, 2), m_channels=32, feat_dim=80): super().__init__() self.in_planes = m_channels self.conv1 = nn.Conv2d(1, m_channels, 3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(m_channels) self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2) self.conv2 = nn.Conv2d(m_channels, m_channels, 3, stride=(2, 1), padding=1, bias=False) self.bn2 = nn.BatchNorm2d(m_channels) self.out_channels = m_channels * (feat_dim // 8) # 32 * 10 = 320 def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for s in strides: layers.append(block(self.in_planes, planes, s)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): x = x.unsqueeze(1) # (B, 1, 80, T) x = F.relu(self.bn1(self.conv1(x))) x = self.layer1(x) x = self.layer2(x) x = F.relu(self.bn2(self.conv2(x))) B, C, freq, T = x.shape return x.reshape(B, C * freq, T) # (B, 320, T) class TDNNLayer(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False, config_str='batchnorm-relu'): super().__init__() if padding < 0: assert kernel_size % 2 == 1 padding = (kernel_size - 1) // 2 * dilation self.linear = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): return self.nonlinear(self.linear(x)) class CAMLayer(nn.Module): def __init__(self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2): super().__init__() self.linear_local = nn.Conv1d(bn_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1) self.relu = nn.ReLU(inplace=True) self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1) self.sigmoid = nn.Sigmoid() def seg_pooling(self, x, seg_len=100, stype='avg'): if stype == 'avg': seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) else: seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) shape = seg.shape seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) return seg[..., :x.shape[-1]] def forward(self, x): y = self.linear_local(x) context = x.mean(-1, keepdim=True) + self.seg_pooling(x) context = self.relu(self.linear1(context)) m = self.sigmoid(self.linear2(context)) return y * m class CAMDenseTDNNLayer(nn.Module): def __init__(self, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str='batchnorm-relu', memory_efficient=False): super().__init__() assert kernel_size % 2 == 1 padding = (kernel_size - 1) // 2 * dilation self.memory_efficient = memory_efficient self.nonlinear1 = get_nonlinear(config_str, in_channels) self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) self.nonlinear2 = get_nonlinear(config_str, bn_channels) self.cam_layer = CAMLayer(bn_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) def bn_function(self, x): return self.linear1(self.nonlinear1(x)) def forward(self, x): if self.training and self.memory_efficient: x = cp.checkpoint(self.bn_function, x) else: x = self.bn_function(x) return self.cam_layer(self.nonlinear2(x)) class CAMDenseTDNNBlock(nn.ModuleList): def __init__(self, num_layers, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str='batchnorm-relu', memory_efficient=False): layers = [ CAMDenseTDNNLayer( in_channels=in_channels + i * out_channels, out_channels=out_channels, bn_channels=bn_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias, config_str=config_str, memory_efficient=memory_efficient, ) for i in range(num_layers) ] super().__init__(layers) # Name layers tdnnd1, tdnnd2, ... to match checkpoint keys self._modules = OrderedDict( {f"tdnnd{i+1}": layer for i, layer in enumerate(layers)} ) def forward(self, x): for layer in self: x = torch.cat([x, layer(x)], dim=1) return x class TransitLayer(nn.Module): def __init__(self, in_channels, out_channels, bias=True, config_str='batchnorm-relu'): super().__init__() self.nonlinear = get_nonlinear(config_str, in_channels) self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) def forward(self, x): return self.linear(self.nonlinear(x)) class DenseLayer(nn.Module): def __init__(self, in_channels, out_channels, bias=False, config_str='batchnorm-relu'): super().__init__() self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): if x.dim() == 2: x = self.linear(x.unsqueeze(-1)).squeeze(-1) else: x = self.linear(x) return self.nonlinear(x) # ── CAM++ model ─────────────────────────────────────────────────────────────── class CAMPPlus(nn.Module): def __init__(self, feat_dim=80, embedding_size=192, growth_rate=32, bn_size=4, init_channels=128, config_str='batchnorm-relu', memory_efficient=True): super().__init__() self.head = FCM(feat_dim=feat_dim) channels = self.head.out_channels # 320 self.xvector = nn.Sequential(OrderedDict([ ('tdnn', TDNNLayer(channels, init_channels, 5, stride=2, dilation=1, padding=-1, config_str=config_str)), ])) channels = init_channels # 128 for i, (num_layers, kernel_size, dilation) in enumerate( zip((12, 24, 16), (3, 3, 3), (1, 2, 2)) ): block = CAMDenseTDNNBlock( num_layers=num_layers, in_channels=channels, out_channels=growth_rate, bn_channels=bn_size * growth_rate, kernel_size=kernel_size, dilation=dilation, config_str=config_str, memory_efficient=memory_efficient, ) self.xvector.add_module(f'block{i+1}', block) channels += num_layers * growth_rate self.xvector.add_module( f'transit{i+1}', TransitLayer(channels, channels // 2, bias=False, config_str=config_str), ) channels //= 2 self.xvector.add_module('out_nonlinear', get_nonlinear(config_str, channels)) self.stats = StatsPool() self.dense = DenseLayer(channels * 2, embedding_size, config_str='batchnorm_') for m in self.modules(): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.kaiming_normal_(m.weight.data) if m.bias is not None: nn.init.zeros_(m.bias) def load_state_dict(self, state_dict, strict=True): # Remap keys: old checkpoints stored stats/dense inside xvector new_sd = {} for k, v in state_dict.items(): if k.startswith('xvector.stats'): k = k.replace('xvector.stats', 'stats') elif k.startswith('xvector.dense'): k = k.replace('xvector.dense', 'dense') new_sd[k] = v super().load_state_dict(new_sd, strict) def forward(self, x, x_lens=None): x = x.permute(0, 2, 1) # (B, T, 80) → (B, 80, T) x = self.head(x) # (B, 320, T) x = self.xvector(x) # (B, 512, T) x = self.stats(x, x_lens) # (B, 1024) x = self.dense(x) # (B, 192) return x # ── High-level speaker encoder ──────────────────────────────────────────────── class SpeakerEncoder(nn.Module): """Waveform → L2-normalised CAM++ speaker embedding.""" def __init__(self, ckpt_path: str, device: str = "cuda"): super().__init__() self.fbank = FBankExtractor() self.campplus = CAMPPlus() if ckpt_path: self._load_checkpoint(ckpt_path, device) def _load_checkpoint(self, path: str, device: str): ckpt = torch.load(path, map_location=device) sd = ckpt.get("model", ckpt.get("state_dict", ckpt)) self.campplus.load_state_dict(sd, strict=False) self.campplus.eval() print(f"[SpeakerEncoder] Loaded {path}") @torch.no_grad() def extract_embedding(self, wav: torch.Tensor, sr: int = 16000) -> torch.Tensor: """wav: (T,) or (1, T) → (192,) L2-normalised embedding""" if wav.dim() == 2: wav = wav.mean(0) device = next(self.campplus.parameters()).device wav = wav.to(device) if sr != 16000: wav = torchaudio.functional.resample(wav, sr, 16000) feats = self.fbank(wav.unsqueeze(0)) # (1, T_frames, 80) emb = self.campplus(feats).squeeze(0) # (192,) return F.normalize(emb, dim=-1) def forward(self, feats: torch.Tensor) -> torch.Tensor: """feats: (B, T, 80) → (B, 192) L2-normalised""" return F.normalize(self.campplus(feats), dim=-1) # ── Speaker projection (for flow matching model) ────────────────────────────── class SpeakerProjection(nn.Module): def __init__(self, spk_emb_dim: int = 192, hidden_dim: int = 512): super().__init__() self.proj = nn.Sequential( nn.Linear(spk_emb_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim), ) def forward(self, spk_emb: torch.Tensor) -> torch.Tensor: return self.proj(spk_emb)