|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
|
|
|
|
|
# Download locations of the pretrained checkpoints, keyed by the model name
# that `_acoustic` receives as its `name` argument.
URLS = {
    "hubert-discrete": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt",
    "hubert-soft": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt",
}
|
|
|
|
|
class CustomLSTM(nn.Module):
    """Single-layer, batch-first LSTM written with explicit matrix products.

    The four gates share one fused input weight ``W``, one fused recurrent
    weight ``U`` and one fused ``bias``. The fused axis is laid out as
    ``[input gate, forget gate, cell candidate, output gate]``, each block
    being ``hidden_sz`` wide.

    The recurrent state is a pair of 2-D tensors of shape
    ``(batch, hidden_sz)``; there is no leading layer dimension.

    Args:
        input_sz: size of the last dimension of the input.
        hidden_sz: size of the hidden and cell states.
    """

    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # Allocated uninitialised; init_weights() fills them right below.
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter in place from U(-k, k), k = 1 / sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a whole sequence.

        Args:
            x: input of shape ``(batch, sequence, input_sz)``.
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                ``(batch, hidden_sz)``. When omitted, both start at zero.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has shape
            ``(batch, sequence, hidden_sz)`` and ``h_t`` / ``c_t`` are the
            states after the last time step.
        """
        # The 3-way unpack also rejects inputs that are not 3-D.
        batch_size, _, _ = x.size()
        if init_states is None:
            # new_zeros inherits both device and dtype from x, so the module
            # keeps working after .double() / .half(); a plain torch.zeros
            # would be float32 and break the h_t @ U product.
            h_t = x.new_zeros(batch_size, self.hidden_size)
            c_t = x.new_zeros(batch_size, self.hidden_size)
        else:
            h_t, c_t = init_states

        outputs = []
        for x_t in x.unbind(dim=1):
            # One fused affine map produces all four gate pre-activations.
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)
            c_t = torch.sigmoid(f_gate) * c_t + torch.sigmoid(i_gate) * torch.tanh(g_gate)
            h_t = torch.sigmoid(o_gate) * torch.tanh(c_t)
            outputs.append(h_t)

        # Stack along dim 1 to get (batch, sequence, hidden) directly.
        return torch.stack(outputs, dim=1), (h_t, c_t)
|
|
|
|
|
class AcousticModel(nn.Module):
    """Encoder/decoder pair whose decoder input is conditioned on a speaker embedding.

    The encoder output is concatenated, along the last dimension, with the
    speaker embedding repeated for every encoder frame; the result is what
    the decoder consumes.

    Args:
        discrete: forwarded to ``Encoder``.
        upsample: forwarded to ``Encoder``.
        use_custom_lstm: forwarded to ``Decoder``.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm=False):
        super().__init__()
        self.encoder = Encoder(discrete, upsample)
        self.decoder = Decoder(use_custom_lstm=use_custom_lstm)

    def _encode_with_speaker(self, x, spk_embs):
        """Encode ``x`` and append the per-frame copy of ``spk_embs`` on the last dim."""
        encoded = self.encoder(x)
        # assumes spk_embs is (batch, emb_dim) -- the unsqueeze/expand below
        # only works for a 2-D tensor; TODO confirm against callers.
        frames = encoded.size(1)
        tiled = spk_embs.unsqueeze(1).expand(-1, frames, -1)
        return torch.cat([encoded, tiled], dim=-1)

    def forward(self, x: torch.Tensor, spk_embs, mels: torch.Tensor) -> torch.Tensor:
        """Run the decoder on the speaker-conditioned encoding together with ``mels``.

        Args:
            x: encoder input.
            spk_embs: speaker embedding, repeated along the encoder's frame axis.
            mels: second decoder input, passed through unchanged.

        Returns:
            Whatever ``self.decoder(conditioned, mels)`` returns.
        """
        conditioned = self._encode_with_speaker(x, spk_embs)
        return self.decoder(conditioned, mels)

    def forward_test(self, x, spk_embs, mels):
        """Debug helper: print input shapes and the encoder output shape; returns None."""
        for label, tensor in (("x shape", x), ("se shape", spk_embs), ("mels shape", mels)):
            print(label, tensor.shape)
        encoded = self.encoder(x)
        print("x_enc shape", encoded.shape)
        return None

    @torch.inference_mode()
    def generate(self, x: torch.Tensor, spk_embs) -> torch.Tensor:
        """Encode ``x``, attach ``spk_embs`` and let the decoder generate its output.

        Runs under ``torch.inference_mode()``.
        """
        conditioned = self._encode_with_speaker(x, spk_embs)
        return self.decoder.generate(conditioned)
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Optional embedding, a PreNet, then a stack of 1-D convolutions.

    Args:
        discrete: when true, inputs are first looked up in an
            ``nn.Embedding(101, 256)`` table; otherwise no embedding is used.
        upsample: when true, a stride-2 ``ConvTranspose1d`` sits after the
            first convolution stage; otherwise an ``nn.Identity`` takes its
            place so the layer indices inside ``convs`` stay the same.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True):
        super().__init__()
        if discrete:
            self.embedding = nn.Embedding(100 + 1, 256)
        else:
            self.embedding = None
        self.prenet = PreNet(256, 256, 256)

        # Layers are created strictly in their final order so that parameter
        # initialisation consumes the RNG in the same sequence as before.
        stages = self._conv_stage(256)
        if upsample:
            stages.append(nn.ConvTranspose1d(512, 512, kernel_size=4, stride=2, padding=1))
        else:
            stages.append(nn.Identity())
        stages.extend(self._conv_stage(512))
        stages.extend(self._conv_stage(512))
        self.convs = nn.Sequential(*stages)

    @staticmethod
    def _conv_stage(in_channels):
        """Return [Conv1d(k=5, pad=2) -> ReLU -> InstanceNorm1d] with 512 output channels."""
        return [
            nn.Conv1d(in_channels, 512, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.InstanceNorm1d(512),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        Dims 1 and 2 are swapped before the convolutions and swapped back
        afterwards, so input and output share the same axis order.
        """
        hidden = x if self.embedding is None else self.embedding(x)
        hidden = self.prenet(hidden)
        hidden = self.convs(hidden.transpose(1, 2))
        return hidden.transpose(1, 2)
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Three stacked LSTMs with residual connections and a linear output layer.

    Each step's input is the conditioning vector (1024 features) concatenated
    with the PreNet-processed previous output frame (128 -> 256 features);
    the final projection maps the 768-wide LSTM output to a 128-wide frame.

    All tensors are batch-first: ``(batch, time, feature)``.

    Args:
        use_custom_lstm: use ``CustomLSTM`` layers instead of ``nn.LSTM``.
            The two differ in recurrent-state shape, which ``generate``
            accounts for.
    """

    def __init__(self, use_custom_lstm=False):
        super().__init__()
        self.use_custom_lstm = use_custom_lstm
        self.prenet = PreNet(128, 256, 256)
        if use_custom_lstm:
            self.lstm1 = CustomLSTM(1024 + 256, 768)
            self.lstm2 = CustomLSTM(768, 768)
            self.lstm3 = CustomLSTM(768, 768)
        else:
            # batch_first=True is required: forward() and generate() feed
            # (batch, time, feature) tensors and (1, batch, 768) states.
            # Without it nn.LSTM would treat the batch axis as time.
            self.lstm1 = nn.LSTM(1024 + 256, 768, batch_first=True)
            self.lstm2 = nn.LSTM(768, 768, batch_first=True)
            self.lstm3 = nn.LSTM(768, 768, batch_first=True)
        self.proj = nn.Linear(768, 128, bias=False)

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Teacher-forced pass.

        Args:
            x: conditioning sequence; last dimension must be 1024.
            mels: frames fed to the PreNet; last dimension must be 128 and
                the leading dimensions must match ``x``.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            of 128.
        """
        mels = self.prenet(mels)
        x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
        # Residual connections around the second and third LSTM.
        out2, _ = self.lstm2(x)
        x = x + out2
        out3, _ = self.lstm3(x)
        x = x + out3
        return self.proj(x)

    def _zero_state(self, xs):
        """Return a zero ``(h, c)`` pair shaped for the configured LSTM type."""
        batch = xs.size(0)
        # nn.LSTM expects (num_layers, batch, hidden); CustomLSTM (batch, hidden).
        shape = (batch, 768) if self.use_custom_lstm else (1, batch, 768)
        return xs.new_zeros(shape), xs.new_zeros(shape)

    @torch.inference_mode()
    def generate(self, xs: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce one 128-wide frame per step of ``xs``.

        Starts from an all-zero frame and zero LSTM states; every produced
        frame is fed back through the PreNet as input for the next step.

        Args:
            xs: conditioning sequence of shape ``(batch, steps, 1024)``.

        Returns:
            Tensor of shape ``(batch, steps, 128)``.
        """
        # new_zeros keeps the start frame and states on xs's device and dtype.
        m = xs.new_zeros(xs.size(0), 128)
        h1, c1 = self._zero_state(xs)
        h2, c2 = self._zero_state(xs)
        h3, c3 = self._zero_state(xs)

        mel = []
        for x in torch.unbind(xs, dim=1):
            m = self.prenet(m)
            # Add a length-1 time axis so each LSTM sees (batch, 1, feature).
            x = torch.cat((x, m), dim=1).unsqueeze(1)
            x1, (h1, c1) = self.lstm1(x, (h1, c1))
            x2, (h2, c2) = self.lstm2(x1, (h2, c2))
            x = x1 + x2
            x3, (h3, c3) = self.lstm3(x, (h3, c3))
            x = x + x3
            m = self.proj(x).squeeze(1)
            mel.append(m)
        return torch.stack(mel, dim=1)
|
|
|
|
|
|
|
|
class PreNet(nn.Module):
    """Two ``Linear -> ReLU -> Dropout`` stages applied back to back.

    Args:
        input_size: feature size expected on the last input dimension.
        hidden_size: output size of the first linear layer.
        output_size: output size of the second linear layer.
        dropout: drop probability used by both dropout layers.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        widths = ((input_size, hidden_size), (hidden_size, output_size))
        layers = []
        for fan_in, fan_out in widths:
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        # Indices 0..5 inside `net` are Linear, ReLU, Dropout, Linear, ReLU, Dropout.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both stages to ``x`` and return the result."""
        out = self.net(x)
        return out
|
|
|
|
|
|
|
|
def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Build an ``AcousticModel`` and optionally load released weights into it.

    Args:
        name: key into ``URLS`` selecting the checkpoint to download.
        discrete: forwarded to ``AcousticModel``.
        upsample: forwarded to ``AcousticModel``.
        pretrained: when true, download the checkpoint, load it and switch
            the model to eval mode; when false, return the freshly
            initialised model as is.
        progress: show a progress bar while downloading.

    Returns:
        The constructed model.
    """
    model = AcousticModel(discrete, upsample)
    if not pretrained:
        return model

    checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
    weights = checkpoint["acoustic-model"]
    # Strip the "module." prefix (in place) if the checkpoint keys carry it.
    consume_prefix_in_state_dict_if_present(weights, "module.")
    # NOTE(review): load_state_dict is strict by default, so the checkpoint's
    # tensor shapes must match this file's layer sizes -- confirm.
    model.load_state_dict(weights)
    model.eval()
    return model
|
|
|
|
|
|
|
|
def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Discrete acoustic model.

    Model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model

    Returns:
        The model built with ``discrete=True`` and ``upsample=True``.
    """
    options = dict(
        discrete=True,
        upsample=True,
        pretrained=pretrained,
        progress=progress,
    )
    return _acoustic("hubert-discrete", **options)
|
|
|
|
|
|
|
|
def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Soft acoustic model.

    Model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model

    Returns:
        The model built with ``discrete=False`` and ``upsample=True``.
    """
    options = dict(
        discrete=False,
        upsample=True,
        pretrained=pretrained,
        progress=progress,
    )
    return _acoustic("hubert-soft", **options)