| """ |
| CAM++ speaker encoder. |
| Architecture from: https://github.com/Plachtaa/seed-vc/blob/main/modules/campplus/ |
| |
| Loads pretrained/campplus_cn_common.bin directly. |
| """ |
|
|
| from collections import OrderedDict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as cp |
| import torchaudio |
|
|
|
|
| |
|
|
class FBankExtractor(nn.Module):
    """Compute 80-bin log-Mel filterbank features from 16 kHz audio.

    Frames are 400 samples (25 ms) long with a 160-sample (10 ms) hop, using
    a Hamming window, a 512-point FFT and HTK-scale Mel filters spanning
    20-7600 Hz. Each Mel bin is mean-normalised over time per utterance.

    NOTE(review): upstream CAM++ recipes appear to use Kaldi-style fbank
    features; these torchaudio MelSpectrogram features may not match what the
    pretrained checkpoint saw during training -- confirm before relying on
    embedding quality.
    """

    def __init__(self):
        super().__init__()
        mel_kwargs = dict(
            sample_rate=16000,
            n_fft=512,
            win_length=400,
            hop_length=160,
            f_min=20.0,
            f_max=7600.0,
            n_mels=80,
            window_fn=torch.hamming_window,
            mel_scale="htk",
            norm=None,
        )
        self.fbank = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Map a waveform ``(B, T)`` or ``(T,)`` to features ``(B, T_frames, 80)``."""
        batch = wav[None] if wav.dim() == 1 else wav
        mel = self.fbank(batch)  # (B, 80, T_frames)
        log_mel = mel.clamp(min=1e-6).log()
        # Remove each bin's average over the time axis (per utterance).
        log_mel = log_mel - log_mel.mean(dim=-1, keepdim=True)
        return log_mel.transpose(1, 2)
|
|
|
|
| |
|
|
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` from a dash-separated layer spec.

    Tokens are applied left to right:
      * ``relu``       -> ``nn.ReLU(inplace=True)``, registered as ``relu``
      * ``prelu``      -> ``nn.PReLU(channels)``, registered as ``prelu``
      * ``batchnorm``  -> ``nn.BatchNorm1d(channels)``, registered as ``batchnorm``
      * ``batchnorm_`` -> ``nn.BatchNorm1d(channels, affine=False)``, also
        registered as ``batchnorm`` (so state-dict keys read ``...batchnorm...``
        for both variants)

    Raises:
        ValueError: if a token is not one of the above.
    """
    factories = {
        'relu': ('relu', lambda: nn.ReLU(inplace=True)),
        'prelu': ('prelu', lambda: nn.PReLU(channels)),
        'batchnorm': ('batchnorm', lambda: nn.BatchNorm1d(channels)),
        'batchnorm_': ('batchnorm', lambda: nn.BatchNorm1d(channels, affine=False)),
    }
    seq = nn.Sequential()
    for token in config_str.split('-'):
        entry = factories.get(token)
        if entry is None:
            raise ValueError(f'Unexpected module ({token}).')
        module_name, make = entry
        seq.add_module(module_name, make())
    return seq
|
|
|
|
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True):
    """Concatenate the mean and standard deviation of ``x`` taken along ``dim``.

    The two statistics are joined on the last axis. With ``keepdim=True`` a
    singleton axis is re-inserted at ``dim`` afterwards.
    """
    pooled = torch.cat(
        (x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)),
        dim=-1,
    )
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
|
|
|
|
def masked_statistics_pooling(x, x_lens, dim=-1, keepdim=False, unbiased=True):
    """Per-item mean/std pooling that only looks at each item's valid prefix.

    For batch item ``i`` only ``x[i, :, :x_lens[i]]`` is used. ``dim`` and
    ``unbiased`` apply to that per-item slice. The per-item ``[mean, std]``
    vectors are stacked along a new leading batch axis. ``keepdim=True``
    re-inserts a singleton axis at ``dim``.
    """
    def _pool_one(sample):
        return torch.cat(
            [sample.mean(dim=dim), sample.std(dim=dim, unbiased=unbiased)],
            dim=-1,
        )

    pooled = torch.stack(
        [_pool_one(x[idx, :, :length]) for idx, length in enumerate(x_lens)],
        dim=0,
    )
    if keepdim:
        pooled = pooled.unsqueeze(dim=dim)
    return pooled
|
|
|
|
class StatsPool(nn.Module):
    """Mean+std pooling over the last axis, optionally honouring per-item lengths."""

    def forward(self, x, x_lens=None):
        # Without lengths every time step contributes to the statistics.
        if x_lens is None:
            return statistics_pooling(x)
        return masked_statistics_pooling(x, x_lens)
|
|
|
|
class BasicResBlock(nn.Module):
    """ResNet basic block (two 3x3 convs + shortcut) on ``(B, C, H, W)`` maps.

    ``stride`` is applied along ``H`` only, while ``W`` keeps stride 1. Inside
    ``FCM``, ``H`` is the frequency axis, so time resolution is untouched. The
    identity shortcut is replaced by a strided 1x1 conv + BatchNorm whenever
    the spatial size or channel count changes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        hw_stride = (stride, 1)
        width = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=hw_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == width:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, 1, stride=hw_stride, bias=False),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + self.shortcut(x))
|
|
|
|
class FCM(nn.Module):
    """2D residual front-end that folds the frequency axis into channels.

    Input ``(B, feat_dim, T)`` is viewed as a one-channel image
    ``(B, 1, feat_dim, T)``. Three stride-2 stages (``layer1``, ``layer2``,
    ``conv2``) shrink only the frequency axis; time keeps its length. The
    result is flattened to ``(B, out_channels, T)`` with
    ``out_channels = m_channels * ceil(feat_dim / 8)``, which is
    ``(B, 320, T)`` for the default 80-bin features.
    """

    def __init__(self, block=BasicResBlock, num_blocks=(2, 2),
                 m_channels=32, feat_dim=80):
        super().__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
        self.conv2 = nn.Conv2d(m_channels, m_channels, 3,
                               stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Every stride-2 stage maps a frequency size f to ceil(f / 2). This
        # holds for the 3x3/pad-1 convs and the 1x1/pad-0 shortcut alike, so
        # three stages leave ceil(feat_dim / 8) bins. ``feat_dim // 8`` would
        # undercount whenever feat_dim is not a multiple of 8 (e.g. 81 -> 11
        # bins, not 10) and mis-size the first TDNN layer built from this value.
        self.out_channels = m_channels * -(-feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Chain residual blocks; the first uses ``stride``, the rest stride 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(B, feat_dim, T)`` to ``(B, out_channels, T)``."""
        x = x.unsqueeze(1)  # singleton channel axis for the 2D convs
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = F.relu(self.bn2(self.conv2(x)))
        B, C, freq, T = x.shape
        return x.reshape(B, C * freq, T)
|
|
|
|
class TDNNLayer(nn.Module):
    """1D convolution followed by the ``config_str`` nonlinearity stack.

    A negative ``padding`` requests symmetric padding of
    ``(kernel_size - 1) // 2 * dilation`` on each side; ``kernel_size``
    must then be odd.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        if padding >= 0:
            pad = padding
        else:
            assert kernel_size % 2 == 1
            pad = dilation * ((kernel_size - 1) // 2)
        self.linear = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=pad, dilation=dilation, bias=bias,
        )
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        projected = self.linear(x)
        return self.nonlinear(projected)
|
|
|
|
class CAMLayer(nn.Module):
    """Context-aware masking: a local conv gated by pooled context.

    ``linear_local`` produces the output features. A 1x1-conv bottleneck
    (width ``bn_channels // reduction``) turns global-mean plus segment-pooled
    context into a sigmoid mask that rescales them element-wise.
    """

    def __init__(self, bn_channels, out_channels, kernel_size,
                 stride, padding, dilation, bias, reduction=2):
        super().__init__()
        hidden = bn_channels // reduction
        self.linear_local = nn.Conv1d(
            bn_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias,
        )
        self.linear1 = nn.Conv1d(bn_channels, hidden, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(hidden, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool ``x`` over non-overlapping ``seg_len``-step windows, then
        broadcast each window's value back over its steps.

        ``stype='avg'`` averages; any other value takes the max. The last
        window may be partial (``ceil_mode=True``), and the result is trimmed
        back to ``x``'s length.
        """
        pool = F.avg_pool1d if stype == 'avg' else F.max_pool1d
        pooled = pool(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        *lead, n_seg = pooled.shape
        upsampled = pooled.unsqueeze(-1).expand(*lead, n_seg, seg_len)
        return upsampled.reshape(*lead, -1)[..., :x.shape[-1]]

    def forward(self, x):
        local = self.linear_local(x)
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(context))))
        return local * gate
|
|
|
|
class CAMDenseTDNNLayer(nn.Module):
    """One densely connected CAM layer.

    ``nonlinear1`` -> 1x1 bottleneck conv to ``bn_channels`` -> ``nonlinear2``
    -> ``CAMLayer`` producing ``out_channels``. The CAM conv is padded by
    ``(kernel_size - 1) // 2 * dilation`` per side, so with ``stride=1`` the
    time length is preserved.

    With ``memory_efficient=True`` in training mode, the bottleneck
    (``bn_function``) is gradient-checkpointed: its activations are
    recomputed during backward instead of being stored.
    """

    def __init__(self, in_channels, out_channels, bn_channels,
                 kernel_size, stride=1, dilation=1, bias=False,
                 config_str='batchnorm-relu', memory_efficient=False):
        super().__init__()
        assert kernel_size % 2 == 1
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding,
                                  dilation=dilation, bias=bias)

    def bn_function(self, x):
        """Pre-activation followed by the 1x1 bottleneck projection."""
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        if self.training and self.memory_efficient:
            # Request non-reentrant checkpointing explicitly. Recent PyTorch
            # warns when ``use_reentrant`` is omitted, and the reentrant
            # variant returns an output that does not require grad when ``x``
            # does not. That silently drops the gradients of ``nonlinear1`` and
            # ``linear1``.
            x = cp.checkpoint(self.bn_function, x, use_reentrant=False)
        else:
            x = self.bn_function(x)
        return self.cam_layer(self.nonlinear2(x))
|
|
|
|
class CAMDenseTDNNBlock(nn.ModuleList):
    """Stack of ``num_layers`` CAMDenseTDNNLayers with dense connectivity.

    Layer ``i`` (0-based) receives ``in_channels + i * out_channels``
    channels: the block input concatenated with every earlier layer's output.
    The block therefore returns ``in_channels + num_layers * out_channels``
    channels. Layers are registered as ``tdnnd1``, ``tdnnd2``, ..., so they
    appear under those names in the state dict.

    NOTE: because the names are not ``"0"``, ``"1"``, ..., integer indexing
    (``block[0]``) does not work; iterate over the block instead.
    """

    def __init__(self, num_layers, in_channels, out_channels, bn_channels,
                 kernel_size, stride=1, dilation=1, bias=False,
                 config_str='batchnorm-relu', memory_efficient=False):
        super().__init__()
        # Register each layer under an explicit name through the public
        # ``add_module`` API instead of overwriting the private ``_modules``
        # mapping after construction. The resulting names are identical.
        for i in range(num_layers):
            layer = CAMDenseTDNNLayer(
                in_channels=in_channels + i * out_channels,
                out_channels=out_channels,
                bn_channels=bn_channels,
                kernel_size=kernel_size,
                stride=stride, dilation=dilation, bias=bias,
                config_str=config_str, memory_efficient=memory_efficient,
            )
            self.add_module(f"tdnnd{i+1}", layer)

    def forward(self, x):
        # Grow the channel axis: each layer's output is appended to its input.
        for layer in self:
            x = torch.cat([x, layer(x)], dim=1)
        return x
|
|
|
|
class TransitLayer(nn.Module):
    """Pre-activation transition: ``config_str`` nonlinearity, then a 1x1 conv.

    Used between dense blocks to change the channel count.
    """

    def __init__(self, in_channels, out_channels, bias=True,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        activated = self.nonlinear(x)
        return self.linear(activated)
|
|
|
|
class DenseLayer(nn.Module):
    """1x1 conv projection followed by the ``config_str`` nonlinearity.

    2-D ``(B, C)`` inputs are given a temporary length-1 trailing axis for the
    conv and squeezed back afterwards. Other inputs go to the conv unchanged.
    """

    def __init__(self, in_channels, out_channels, bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        is_vector = x.dim() == 2
        h = self.linear(x[..., None] if is_vector else x)
        if is_vector:
            h = h.squeeze(-1)
        return self.nonlinear(h)
|
|
|
|
| |
|
|
class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding network.

    Pipeline: ``FCM`` 2D front-end -> stride-2 ``TDNNLayer`` -> three densely
    connected CAM blocks (12 / 24 / 16 layers, dilations 1 / 2 / 2), each
    followed by a channel-halving ``TransitLayer`` -> output nonlinearity ->
    mean/std ``StatsPool`` -> ``DenseLayer`` projection.

    ``forward`` maps features ``(B, T, feat_dim)`` to ``(B, embedding_size)``.
    """

    def __init__(self, feat_dim=80, embedding_size=192,
                 growth_rate=32, bn_size=4, init_channels=128,
                 config_str='batchnorm-relu', memory_efficient=True):
        super().__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels

        # Frame-level trunk. Submodule names become ``xvector.*`` state-dict
        # keys; renaming them would break loading existing checkpoints.
        self.xvector = nn.Sequential(OrderedDict([
            ('tdnn', TDNNLayer(channels, init_channels, 5,
                               stride=2, dilation=1, padding=-1,
                               config_str=config_str)),
        ]))
        channels = init_channels

        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module(f'block{i+1}', block)
            # Dense connectivity appends growth_rate channels per layer.
            channels += num_layers * growth_rate
            self.xvector.add_module(
                f'transit{i+1}',
                TransitLayer(channels, channels // 2, bias=False,
                             config_str=config_str),
            )
            channels //= 2

        self.xvector.add_module('out_nonlinear', get_nonlinear(config_str, channels))

        # Utterance-level head lives outside ``xvector`` so that ``forward``
        # can pass sequence lengths to the pooling step; ``load_state_dict``
        # remaps checkpoints that nest these under ``xvector``.
        self.stats = StatsPool()
        self.dense = DenseLayer(channels * 2, embedding_size, config_str='batchnorm_')

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _conv1d_out_lens(lens, conv):
        """Map sequence lengths through ``conv``'s output-length formula.

        ``floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``,
        clamped at 0. ``lens`` may be a tensor (a tensor is returned) or any
        iterable of ints (a list is returned).
        """
        (pad,) = conv.padding
        (dil,) = conv.dilation
        (ker,) = conv.kernel_size
        (stride,) = conv.stride
        shift = 2 * pad - dil * (ker - 1) - 1
        if torch.is_tensor(lens):
            out = torch.div(lens + shift, stride, rounding_mode='floor') + 1
            return out.clamp(min=0)
        return [max((int(n) + shift) // stride + 1, 0) for n in lens]

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, accepting checkpoints that nest pooling under ``xvector``.

        Keys starting with ``xvector.stats`` / ``xvector.dense`` are remapped to
        the top-level ``stats`` / ``dense`` submodules used here. Extra keyword
        arguments (e.g. ``assign``) are forwarded to
        ``nn.Module.load_state_dict``. Its result (missing / unexpected keys)
        is returned instead of being discarded.
        """
        prefix_map = (('xvector.stats', 'stats'), ('xvector.dense', 'dense'))
        remapped = {}
        for key, value in state_dict.items():
            for old, new in prefix_map:
                if key.startswith(old):
                    key = new + key[len(old):]
                    break
            remapped[key] = value
        return super().load_state_dict(remapped, strict, **kwargs)

    def forward(self, x, x_lens=None):
        """Embed features ``(B, T, feat_dim)`` into ``(B, embedding_size)``.

        ``x_lens`` (optional) holds the number of valid input frames per item.
        The stride-2 TDNN shortens the time axis, so the lengths are mapped
        through that conv before masked pooling. Using raw input-frame counts
        (roughly twice the pooled length) would let padded frames leak into
        the statistics.
        """
        x = x.permute(0, 2, 1)  # (B, T, F) -> (B, F, T)
        x = self.head(x)
        x = self.xvector(x)
        if x_lens is not None:
            x_lens = self._conv1d_out_lens(x_lens, self.xvector.tdnn.linear)
        x = self.stats(x, x_lens)
        return self.dense(x)
|
|
|
|
| |
|
|
class SpeakerEncoder(nn.Module):
    """Waveform -> L2-normalised CAM++ speaker embedding."""

    def __init__(self, ckpt_path: str, device: str = "cuda"):
        super().__init__()
        self.fbank = FBankExtractor()
        self.campplus = CAMPPlus()
        if ckpt_path:
            self._load_checkpoint(ckpt_path, device)

    def _load_checkpoint(self, path: str, device: str):
        """Load CAM++ weights from ``path`` and put ``campplus`` in eval mode.

        The file may contain a raw state dict or wrap it under ``"model"`` or
        ``"state_dict"``. Loading is non-strict; a note is printed when the
        load reports missing keys.

        ``device`` is kept for backward compatibility but is no longer used as
        ``map_location``. The weights are copied into this module's
        already-allocated parameters and the module is not moved, so
        deserialising onto ``device`` only added a transfer. It also failed
        outright on hosts without CUDA, given the ``"cuda"`` default. Call
        ``.to(device)`` on the encoder to move it.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        sd = ckpt.get("model", ckpt.get("state_dict", ckpt))
        result = self.campplus.load_state_dict(sd, strict=False)
        missing = getattr(result, "missing_keys", None)
        if missing:
            print(f"[SpeakerEncoder] {len(missing)} missing keys when loading {path}")
        self.campplus.eval()
        print(f"[SpeakerEncoder] Loaded {path}")

    @torch.no_grad()
    def extract_embedding(self, wav: torch.Tensor, sr: int = 16000) -> torch.Tensor:
        """Embed a single utterance.

        Args:
            wav: waveform ``(T,)`` or ``(C, T)``; a 2-D input is averaged over
                its first axis.
            sr: sample rate of ``wav``; resampled to 16 kHz when different.

        Returns:
            ``(192,)`` L2-normalised embedding.

        ``campplus`` is run in eval mode for this call, and its previous
        train/eval mode is restored afterwards. In training mode the final
        ``BatchNorm1d`` would reject a batch of one and would use batch
        statistics instead of running statistics.
        """
        if wav.dim() == 2:
            wav = wav.mean(0)
        device = next(self.campplus.parameters()).device
        wav = wav.to(device)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        feats = self.fbank(wav.unsqueeze(0))
        was_training = self.campplus.training
        self.campplus.eval()
        try:
            emb = self.campplus(feats).squeeze(0)
        finally:
            self.campplus.train(was_training)
        return F.normalize(emb, dim=-1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Embed precomputed features ``(B, T, 80)`` -> ``(B, 192)``, L2-normalised.

        Runs ``campplus`` in whatever train/eval mode it is currently in.
        """
        return F.normalize(self.campplus(feats), dim=-1)
|
|
|
|
| |
|
|
class SpeakerProjection(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping a speaker embedding of
    size ``spk_emb_dim`` to ``hidden_dim`` features."""

    def __init__(self, spk_emb_dim: int = 192, hidden_dim: int = 512):
        super().__init__()
        layers = [
            nn.Linear(spk_emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, spk_emb: torch.Tensor) -> torch.Tensor:
        """Project ``(..., spk_emb_dim)`` embeddings to ``(..., hidden_dim)``."""
        projected = self.proj(spk_emb)
        return projected
|
|