import torch
from torch import nn

from so_vits_svc_fork.modules import attentions as attentions
from so_vits_svc_fork.modules import commons as commons
from so_vits_svc_fork.modules import modules as modules

class SpeakerEncoder(torch.nn.Module):
    """LSTM speaker encoder: maps a mel spectrogram to a fixed-size, L2-normalized embedding."""

    def __init__(
        self,
        mel_n_channels=80,
        model_num_layers=3,
        model_hidden_size=256,
        model_embedding_size=256,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
        )
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

    def forward(self, mels):
        # mels: (batch, frames, mel_n_channels)
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(mels)
        # Project the final hidden state of the last LSTM layer and L2-normalize it.
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
        # Index ranges for overlapping windows of `partial_frames` frames, `partial_hop` apart.
        mel_slices = []
        for i in range(0, total_frames - partial_frames, partial_hop):
            mel_range = torch.arange(i, i + partial_frames)
            mel_slices.append(mel_range)
        return mel_slices

    def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
        # mel: (1, frames, mel_n_channels). Long utterances are split into overlapping
        # partial windows whose embeddings are averaged; short ones use the last window only.
        mel_len = mel.size(1)
        last_mel = mel[:, -partial_frames:]

        if mel_len > partial_frames:
            mel_slices = self.compute_partial_slices(
                mel_len, partial_frames, partial_hop
            )
            mels = list(mel[:, s] for s in mel_slices)
            mels.append(last_mel)
            mels = torch.stack(tuple(mels), 0).squeeze(1)

            with torch.no_grad():
                partial_embeds = self(mels)
            embed = torch.mean(partial_embeds, dim=0).unsqueeze(0)
            # embed = embed / torch.linalg.norm(embed, 2)
        else:
            with torch.no_grad():
                embed = self(last_mel)

        return embed
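
# Illustrative usage sketch (not part of the library): shows the tensor shapes
# SpeakerEncoder expects. The function name, the random mel input, and the 400-frame
# length are assumptions chosen for this example, not values required by so_vits_svc_fork.
def _speaker_encoder_demo():
    encoder = SpeakerEncoder()  # defaults: 80 mel channels, 256-dim embedding
    mel = torch.randn(1, 400, 80)  # (1, frames, mel_n_channels)
    embed = encoder.embed_utterance(mel, partial_frames=128, partial_hop=64)
    # One utterance-level embedding of size model_embedding_size.
    assert embed.shape == (1, 256)
    return embed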

class Encoder(nn.Module):
    """WaveNet-style encoder (modules.WN) that projects its input to the mean and
    log-variance of a Gaussian and samples z via the reparameterization trick."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        # x: (batch, in_channels, frames); x_lengths: (batch,); g: optional conditioning.
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Sample z ~ N(m, exp(logs)^2), masked to the valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
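
# Illustrative usage sketch (not part of the library): builds a small Encoder and runs
# random features through it to show the expected shapes. The function name, channel
# sizes, layer count, and sequence lengths below are assumptions picked for the example.
def _encoder_demo():
    enc = Encoder(
        in_channels=80,      # e.g. spectrogram bins (assumed for the example)
        out_channels=192,
        hidden_channels=192,
        kernel_size=5,
        dilation_rate=1,
        n_layers=16,
    )
    x = torch.randn(2, 80, 100)          # (batch, in_channels, frames)
    x_lengths = torch.tensor([100, 80])  # valid frames per batch item
    z, m, logs, x_mask = enc(x, x_lengths)
    # z, m, logs: (batch, out_channels, frames); x_mask: (batch, 1, frames)
    assert z.shape == (2, 192, 100) and x_mask.shape == (2, 1, 100)
    return z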

class TextEncoder(nn.Module):
    """Self-attention encoder (attentions.Encoder) over content features, conditioned on
    coarse F0 via an embedding; outputs a sampled latent plus the Gaussian statistics."""

    def __init__(
        self,
        out_channels,
        hidden_channels,
        kernel_size,
        n_layers,
        gin_channels=0,
        filter_channels=None,
        n_heads=None,
        p_dropout=None,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
        # 256 coarse F0 buckets, embedded into the hidden dimension.
        self.f0_emb = nn.Embedding(256, hidden_channels)
        self.enc_ = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

    def forward(self, x, x_mask, f0=None, noice_scale=1):
        # x: (batch, hidden_channels, frames); x_mask: (batch, 1, frames);
        # f0: coarse F0 indices of shape (batch, frames), looked up in self.f0_emb.
        x = x + self.f0_emb(f0).transpose(1, 2)
        x = self.enc_(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Sample z ~ N(m, (exp(logs) * noice_scale)^2), masked to the valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
        return z, m, logs, x_mask
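
# Illustrative usage sketch (not part of the library): shows how TextEncoder is driven by
# content features plus coarse F0 indices. The function name, channel sizes, head count,
# and dropout are assumptions chosen for this example, not values required by the module.
def _text_encoder_demo():
    enc = TextEncoder(
        out_channels=192,
        hidden_channels=192,
        kernel_size=3,
        n_layers=6,
        filter_channels=768,
        n_heads=2,
        p_dropout=0.1,
    )
    x = torch.randn(2, 192, 100)          # (batch, hidden_channels, frames)
    x_mask = torch.ones(2, 1, 100)        # all frames valid
    f0 = torch.randint(0, 256, (2, 100))  # coarse F0 bucket indices
    z, m, logs, x_mask = enc(x, x_mask, f0=f0)
    # z, m, logs: (batch, out_channels, frames)
    assert z.shape == (2, 192, 100)
    return z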