Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present | |
# Download locations of the pretrained checkpoints, keyed by model name.
URLS = {
    "hubert-discrete": (
        "https://github.com/bshall/acoustic-model/releases/download/v0.1/"
        "hubert-discrete-d49e1c77.pt"
    ),
    "hubert-soft": (
        "https://github.com/bshall/acoustic-model/releases/download/v0.1/"
        "hubert-soft-0321fd7e.pt"
    ),
}
class AcousticModel(nn.Module):
    """Encoder/decoder pair that maps speech units to output frames.

    The encoder turns the unit sequence into a feature sequence. The decoder
    then either consumes those features together with caller-supplied frames
    (``forward``) or feeds back its own predictions step by step
    (``generate``).
    """

    def __init__(self, discrete: bool = False, upsample: bool = True):
        """Build the encoder and decoder.

        Args:
            discrete (bool): passed to the encoder; when True the encoder
                embeds its input before any other processing
            upsample (bool): passed to the encoder; when True the encoder
                includes a transposed-convolution stage
        """
        super().__init__()
        self.encoder = Encoder(discrete=discrete, upsample=upsample)
        self.decoder = Decoder()

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and decode it conditioned on the given ``mels``.

        Args:
            x (torch.Tensor): encoder input
            mels (torch.Tensor): frames handed to the decoder alongside the
                encoded features
        Returns:
            torch.Tensor: the decoder output
        """
        encoded = self.encoder(x)
        return self.decoder(encoded, mels)

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and let the decoder generate frames autoregressively.

        Args:
            x (torch.Tensor): encoder input
        Returns:
            torch.Tensor: the frames produced by ``Decoder.generate``
        """
        encoded = self.encoder(x)
        return self.decoder.generate(encoded)
class Encoder(nn.Module):
    """Pre-net followed by a stack of 1-D convolutions over the time axis.

    Input is laid out as (batch, time, features): the pre-net works on the
    last dimension, and the tensor is transposed to (batch, features, time)
    only for the convolution stack.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True):
        """Create the encoder layers.

        Args:
            discrete (bool): when True, add a 101-entry, 256-dimensional
                embedding table applied to the input first
            upsample (bool): when True, include a transposed convolution
                (kernel 4, stride 2, padding 1) that doubles the time length;
                otherwise an identity layer takes its place
        """
        super().__init__()
        if discrete:
            self.embedding = nn.Embedding(100 + 1, 256)
        else:
            self.embedding = None
        self.prenet = PreNet(256, 256, 256)

        # Layers are created in their final order so that both the
        # ``convs.<index>`` parameter names and the order in which random
        # initialisation is drawn stay exactly as before.
        stages = [
            nn.Conv1d(256, 512, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.InstanceNorm1d(512),
        ]
        if upsample:
            stages.append(
                nn.ConvTranspose1d(512, 512, kernel_size=4, stride=2, padding=1)
            )
        else:
            # Placeholder keeps the indices of the following layers fixed.
            stages.append(nn.Identity())
        for _ in range(2):
            stages.extend(
                [
                    nn.Conv1d(512, 512, kernel_size=5, stride=1, padding=2),
                    nn.ReLU(),
                    nn.InstanceNorm1d(512),
                ]
            )
        self.convs = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of unit sequences.

        Args:
            x (torch.Tensor): (batch, time) indices when an embedding table
                is present, otherwise (batch, time, 256) features
        Returns:
            torch.Tensor: (batch, time', 512) features, where time' is twice
            the input length when the upsampling stage is enabled
        """
        features = x if self.embedding is None else self.embedding(x)
        features = self.prenet(features)
        # Conv1d expects channels on dim 1, so swap to (batch, channels, time).
        channels_first = features.transpose(1, 2)
        encoded = self.convs(channels_first)
        return encoded.transpose(1, 2)
class Decoder(nn.Module):
    """Three-layer LSTM decoder with residual connections.

    Each step consumes one 512-dimensional encoder frame concatenated with
    the pre-net projection of a 128-dimensional frame, and emits a new
    128-dimensional frame through a bias-free linear projection.
    """

    def __init__(self):
        super().__init__()
        self.prenet = PreNet(128, 256, 256)
        self.lstm1 = nn.LSTM(512 + 256, 768, batch_first=True)
        self.lstm2 = nn.LSTM(768, 768, batch_first=True)
        self.lstm3 = nn.LSTM(768, 768, batch_first=True)
        self.proj = nn.Linear(768, 128, bias=False)

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Decode the whole sequence at once from caller-supplied frames.

        Args:
            x (torch.Tensor): (batch, time, 512) encoder features
            mels (torch.Tensor): (batch, time, 128) frames concatenated with
                ``x`` along the last dimension after the pre-net; presumably
                the shifted target frames for teacher forcing (confirm
                against the training code)
        Returns:
            torch.Tensor: (batch, time, 128) predicted frames
        """
        mels = self.prenet(mels)
        x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
        # The second and third LSTM layers are wrapped in residual connections.
        res = x
        x, _ = self.lstm2(x)
        x = res + x
        res = x
        x, _ = self.lstm3(x)
        x = res + x
        return self.proj(x)

    @staticmethod
    def _initial_state(lstm: nn.LSTM, like: torch.Tensor):
        """Return zero ``(h, c)`` states for ``lstm`` matching ``like``'s batch size, dtype and device."""
        shape = (1, like.size(0), lstm.hidden_size)
        return like.new_zeros(shape), like.new_zeros(shape)

    def generate(self, xs: torch.Tensor) -> torch.Tensor:
        """Generate frames autoregressively, one encoder frame per step.

        Generation starts from an all-zero frame and zero LSTM states; the
        frame predicted at each step is fed back as input to the next one.
        All buffers take their dtype and device from ``xs`` so the method
        also works when the model runs in half or double precision.

        Args:
            xs (torch.Tensor): (batch, time, 512) encoder features
        Returns:
            torch.Tensor: (batch, time, 128) generated frames; an empty
            (batch, 0, 128) tensor when ``xs`` has no time steps
        """
        batch_size = xs.size(0)
        frame_size = self.proj.out_features

        # torch.stack rejects an empty list, so handle "no steps" up front.
        if xs.size(1) == 0:
            return xs.new_zeros(batch_size, 0, frame_size)

        m = xs.new_zeros(batch_size, frame_size)
        state1 = self._initial_state(self.lstm1, xs)
        state2 = self._initial_state(self.lstm2, xs)
        state3 = self._initial_state(self.lstm3, xs)

        mel = []
        for x in torch.unbind(xs, dim=1):
            m = self.prenet(m)
            # Add a length-1 time axis so the batch_first LSTMs see one step.
            x = torch.cat((x, m), dim=1).unsqueeze(1)
            x1, state1 = self.lstm1(x, state1)
            x2, state2 = self.lstm2(x1, state2)
            x = x1 + x2
            x3, state3 = self.lstm3(x, state3)
            x = x + x3
            m = self.proj(x).squeeze(1)
            mel.append(m)
        return torch.stack(mel, dim=1)
class PreNet(nn.Module):
    """Two fully connected layers, each followed by ReLU and dropout."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.5,
    ):
        """Create the layer stack.

        Args:
            input_size (int): size of the last dimension of the input
            hidden_size (int): width of the first linear layer
            output_size (int): width of the second linear layer
            dropout (float): dropout probability used after each ReLU
        """
        super().__init__()
        layers = []
        # Built in order so the ``net.<index>`` parameter names are unchanged.
        for fan_in, fan_out in ((input_size, hidden_size), (hidden_size, output_size)):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to the last dimension of ``x``."""
        return self.net(x)
def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Build an acoustic model and optionally load its pretrained weights.

    Args:
        name (str): key into ``URLS`` selecting the checkpoint to download
        discrete (bool): forwarded to ``AcousticModel``
        upsample (bool): forwarded to ``AcousticModel``
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    Returns:
        AcousticModel: the model, switched to eval mode when pretrained
        weights were loaded; otherwise left as constructed
    Raises:
        KeyError: if ``pretrained`` is True and ``name`` is not in ``URLS``
    """
    acoustic = AcousticModel(discrete, upsample)
    if not pretrained:
        return acoustic

    # Deserialise onto the CPU so a checkpoint holding CUDA tensors can still
    # be loaded on a machine without a GPU. load_state_dict copies the values
    # into the model's own parameters, so the model's device does not change.
    checkpoint = torch.hub.load_state_dict_from_url(
        URLS[name], map_location="cpu", progress=progress
    )
    state_dict = checkpoint["acoustic-model"]
    # Strip the "module." prefix that a DataParallel-style wrapper leaves on
    # parameter names, if it is present.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    acoustic.load_state_dict(state_dict)
    acoustic.eval()
    return acoustic
def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Discrete acoustic model.

    Described in `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    options = dict(
        discrete=True,
        upsample=True,
        pretrained=pretrained,
        progress=progress,
    )
    return _acoustic("hubert-discrete", **options)
def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Soft acoustic model.

    Described in `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    options = dict(
        discrete=False,
        upsample=True,
        pretrained=pretrained,
        progress=progress,
    )
    return _acoustic("hubert-soft", **options)