| | |
| | |
| | |
| | |
| |
|
| |
|
| | import librosa |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.utils.data |
| | import torchaudio |
| |
|
| |
|
# Hyper-parameters of the pretrained d-vector speech embedder network.
EMBEDDER_PARAMS = {
    'num_mels': 40,      # mel channels fed to the LSTM
    'n_fft': 512,        # STFT size used to compute those mels
    'emb_dim': 256,      # dimensionality of the speaker embedding
    'lstm_hidden': 768,  # hidden units per LSTM layer
    'lstm_layers': 3,    # number of stacked LSTM layers
    'window': 80,        # frames per sliding window over the mel sequence
    'stride': 40,        # hop between consecutive windows (50% overlap)
}
| |
|
| |
|
def set_requires_grad(nets, requires_grad=False):
    """Enable or disable gradient computation for one or more networks.

    Setting ``requires_grad=False`` freezes the networks and avoids
    unnecessary gradient computation (e.g. for a pretrained embedder
    used only for inference).

    Parameters:
        nets (network list) -- a single network or a list of networks;
            ``None`` entries are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is None:
            continue  # tolerate placeholder entries in the list
        for param in net.parameters():
            param.requires_grad = requires_grad
| |
|
| |
|
class LinearNorm(nn.Module):
    """Linear projection from the LSTM hidden size down to the embedding size."""

    def __init__(self, hp):
        """hp: hyper-parameter dict providing ``lstm_hidden`` and ``emb_dim``."""
        super(LinearNorm, self).__init__()
        # Attribute name kept as ``linear_layer`` for checkpoint compatibility.
        self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"])

    def forward(self, x):
        """Project ``x`` of shape (..., lstm_hidden) to (..., emb_dim)."""
        return self.linear_layer(x)
| |
|
| |
|
class SpeechEmbedder(nn.Module):
    """d-vector extractor: runs an LSTM over sliding windows of a mel
    spectrogram, projects and L2-normalises each window's final state,
    then averages the windows into a single unit-norm embedding."""

    def __init__(self, hp):
        """hp: hyper-parameter dict (see EMBEDDER_PARAMS)."""
        super(SpeechEmbedder, self).__init__()
        self.lstm = nn.LSTM(hp["num_mels"],
                            hp["lstm_hidden"],
                            num_layers=hp["lstm_layers"],
                            batch_first=True)
        self.proj = LinearNorm(hp)
        self.hp = hp

    def forward(self, mel):
        # mel: assumed (num_mels, T) log-mel spectrogram -- TODO confirm
        # against the caller (SpkrEmbedder.forward squeezes to 2-D).
        # Slice the time axis into overlapping windows:
        # (num_mels, n_windows, window)
        windows = mel.unfold(1, self.hp["window"], self.hp["stride"])
        # -> (n_windows, window, num_mels): windows act as the LSTM batch.
        windows = windows.permute(1, 2, 0)
        out, _ = self.lstm(windows)
        last = out[:, -1, :]  # final time-step hidden state per window
        emb = self.proj(last)
        # Unit-normalise each per-window embedding.
        emb = emb / torch.norm(emb, p=2, dim=1, keepdim=True)
        # Average the windows into one utterance-level d-vector.
        emb = emb.mean(dim=0)
        # Re-normalise, guarding against a degenerate all-zero vector.
        if emb.norm(p=2) != 0:
            emb = emb / emb.norm(p=2)
        return emb
| |
|
| |
|
class SpkrEmbedder(nn.Module):
    """Wraps a pretrained, frozen SpeechEmbedder to turn raw waveforms
    into a single unit-norm speaker d-vector."""

    RATE = 16000  # sample rate the embedder was trained on

    def __init__(
        self,
        embedder_path,
        embedder_params=EMBEDDER_PARAMS,
        rate=16000,
        hop_length=160,
        win_length=400,
        pad=False,
    ):
        """
        Args:
            embedder_path: path to the pretrained SpeechEmbedder state dict.
            embedder_params: hyper-parameter dict (see EMBEDDER_PARAMS).
            rate: sample rate of incoming audio; a Resample transform is
                created when it differs from RATE.
            hop_length: STFT hop size in samples.
            win_length: STFT window size in samples.
            pad: if True, right-pad inputs shorter than 14000 samples.
        """
        super(SpkrEmbedder, self).__init__()
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        embedder_pt = torch.load(embedder_path, map_location="cpu")
        self.embedder = SpeechEmbedder(embedder_params)
        self.embedder.load_state_dict(embedder_pt)
        self.embedder.eval()
        # Embedder is inference-only: freeze all parameters.
        set_requires_grad(self.embedder, requires_grad=False)
        self.embedder_params = embedder_params

        # Mel filterbank matching the embedder's training configuration.
        self.register_buffer('mel_basis', torch.from_numpy(
            librosa.filters.mel(
                sr=self.RATE,
                n_fft=self.embedder_params["n_fft"],
                n_mels=self.embedder_params["num_mels"])
        )
        )

        # NOTE(review): self.resample is constructed but never applied in
        # forward()/get_mel(); callers presumably resample externally or
        # call it themselves -- confirm.
        self.resample = None
        if rate != self.RATE:
            self.resample = torchaudio.transforms.Resample(rate, self.RATE)
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad = pad

    def get_mel(self, y):
        """Compute the log mel power spectrogram of waveform ``y``.

        Returns log10(mel_basis @ |STFT(y)|^2 + 1e-6).
        """
        if self.pad and y.shape[-1] < 14000:
            # Right-pad very short clips so at least a few frames exist.
            y = F.pad(y, (0, 14000 - y.shape[-1]))

        # Build the window on the same device/dtype as the input.
        window = torch.hann_window(self.win_length).to(y)
        # return_complex=True is mandatory in torch >= 2.0; the legacy
        # real-valued output (trailing real/imag dim) was removed.
        spec = torch.stft(y, n_fft=self.embedder_params["n_fft"],
                          hop_length=self.hop_length,
                          win_length=self.win_length,
                          window=window,
                          return_complex=True)
        # Power spectrogram: |STFT|^2, numerically identical to the former
        # torch.norm(real_imag, dim=-1, p=2) ** 2.
        magnitudes = spec.abs() ** 2
        mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
        return mel

    def forward(self, inputs):
        """Embed each waveform in ``inputs`` and average into one
        unit-norm speaker d-vector."""
        dvecs = []
        for wav in inputs:
            mel = self.get_mel(wav)
            if mel.dim() == 3:
                # Drop a singleton batch dim so the embedder sees (mels, T).
                mel = mel.squeeze(0)
            dvecs += [self.embedder(mel)]
        dvecs = torch.stack(dvecs)

        dvec = torch.mean(dvecs, dim=0)
        dvec = dvec / torch.norm(dvec)

        return dvec
| |
|