Spaces:
Sleeping
Sleeping
from typing import Optional, Sequence

import lightning as L
import torch
from auraloss.freq import MultiResolutionSTFTLoss
from torch import Tensor, nn
from torch.optim import Optimizer

from .blocks import DownResConvBlock, UpResConvBlock
class Encoder(nn.Module):
    """Convolutional VAE encoder: maps a (batch, channels, features) signal to
    the mean and log-variance of a diagonal-Gaussian approximate posterior.

    The signal is downsampled by a factor of 2 per ``DownResConvBlock``, then
    flattened and projected to ``2 * out_features`` values which ``forward``
    splits into ``(mean, logvar)``.
    """

    def __init__(
        self,
        in_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        """
        Args:
            in_channels: number of channels of the input signal.
            in_features: length of the input signal; must be divisible by
                ``2 ** len(channels)`` so the downsampling stack lines up.
            out_features: size of the latent space.
            channels: channel width of each downsampling stage.

        Raises:
            ValueError: if ``in_features`` is not divisible by the total
                downscale factor.
        """
        super().__init__()
        # Accept any Sequence (tuple, etc.) — the pairing below needs a list.
        channels = list(channels)
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if in_features % 2 ** len(channels) != 0:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of "
                f"downscale factor ({2 ** len(channels)})"
            )
        modules: list[nn.Module] = [
            nn.Conv1d(in_channels, channels[0], 1),
            nn.GELU(),
        ]
        # Pair consecutive widths; the final stage keeps the last width.
        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules.append(DownResConvBlock(in_channel, out_channel, 1))
        # Each block halves the feature length — exact integer arithmetic,
        # no float round-trip.
        n_features = in_features // 2 ** len(channels)
        modules += [
            nn.Flatten(),
            nn.Linear(n_features * channels[-1], 2 * out_features),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(mean, logvar)`` of the approximate posterior q(z|x)."""
        mean, logvar = self.net(x).chunk(2, dim=1)
        return mean, logvar
class Decoder(nn.Module):
    """Convolutional VAE decoder: maps a latent vector back to a
    (batch, channels, features) signal in [-1, 1].

    Mirrors ``Encoder``: a linear projection is unflattened to the smallest
    feature length, upsampled by 2 per ``UpResConvBlock``, projected to
    ``out_channels``, and squashed with ``tanh``.
    """

    def __init__(
        self,
        out_channels: int,
        in_features: int,
        out_features: int,
        channels: Sequence[int],
    ) -> None:
        """
        Args:
            out_channels: number of channels of the reconstructed signal.
            in_features: size of the latent vector fed to the decoder.
            out_features: length of the reconstructed signal.
            channels: channel width of each upsampling stage.
        """
        super().__init__()
        # Accept any Sequence (tuple, etc.) — the pairing below needs a list.
        channels = list(channels)
        # Feature length before upsampling — exact integer division instead of
        # a float round-trip through ``int(... / ...)``.
        n_features = out_features // 2 ** len(channels)
        modules: list[nn.Module] = [
            nn.Linear(in_features, n_features * channels[0]),
            nn.Unflatten(-1, (channels[0], n_features)),
        ]
        # Pair consecutive widths; the final stage keeps the last width.
        for in_channel, out_channel in zip(channels, channels[1:] + [channels[-1]]):
            modules.append(UpResConvBlock(in_channel, out_channel, 1))
        modules += [nn.Conv1d(channels[-1], out_channels, 1), nn.GELU()]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Decode latent ``x``; output is squashed to [-1, 1] via tanh."""
        return torch.tanh(self.net(x))
class VAE(L.LightningModule):
    """Variational autoencoder for audio, trained against a multi-resolution
    STFT reconstruction loss plus the analytic Gaussian KL divergence.
    """

    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        learning_rate: float,
    ) -> None:
        """
        Args:
            io_channels: channel count of the input/output signal.
            io_features: length of the input/output signal.
            latent_features: size of the latent space.
            channels: encoder stage widths (the decoder mirrors them).
            learning_rate: AdamW learning rate.
        """
        super().__init__()
        channels = list(channels)
        self.encoder = Encoder(io_channels, io_features, latent_features, channels)
        # The decoder mirrors the encoder stages. Reverse a *copy* — the
        # original ``channels.reverse()`` mutated the caller's list in place.
        self.decoder = Decoder(
            io_channels, latent_features, io_features, list(reversed(channels))
        )
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    def sample(self, eps: Optional[Tensor] = None) -> Tensor:
        """Decode ``eps``, or a fresh draw from the prior when omitted.

        The prior matched by ``reparameterize`` and the KL term is N(0, I),
        so the default draw uses ``randn`` (standard normal), not uniform
        ``rand``, and is placed on the module's device.
        """
        if eps is None:
            eps = torch.randn((1, self.latent_features), device=self.device)
        return self.decoder(eps)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        """Multi-resolution STFT reconstruction loss + Gaussian KL term."""
        audio_loss = self.audio_loss_func(x, x_hat)
        # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latents.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """Sample z ~ N(mean, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Encode, sample a latent, decode; returns ``(x_hat, mean, logvar)``."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        """One training step: reconstruct the batch and return the total loss."""
        x_hat, mean, logvar = self.forward(batch)
        loss = self.loss_function(batch, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """AdamW over all parameters at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
class CVAE(L.LightningModule):
    """Class-conditional VAE for audio.

    The class label is one-hot encoded, embedded to the signal length, and
    concatenated to the embedded audio as one extra channel before encoding;
    the one-hot vector is also concatenated to the latent before decoding.
    """

    def __init__(
        self,
        io_channels: int,
        io_features: int,
        latent_features: int,
        channels: Sequence[int],
        num_classes: int,
        learning_rate: float,
    ) -> None:
        """
        Args:
            io_channels: channel count of the input/output signal.
            io_features: length of the input/output signal.
            latent_features: size of the latent space.
            channels: encoder stage widths (the decoder mirrors them).
            num_classes: number of conditioning classes.
            learning_rate: AdamW learning rate.
        """
        super().__init__()
        channels = list(channels)
        self.class_embedder = nn.Linear(num_classes, io_features)
        self.data_embedder = nn.Conv1d(io_channels, io_channels, kernel_size=1)
        # +1 input channel: the class embedding is appended as an extra row.
        self.encoder = Encoder(io_channels + 1, io_features, latent_features, channels)
        # The decoder mirrors the encoder stages. Reverse a *copy* — the
        # original ``channels.reverse()`` mutated the caller's list in place.
        self.decoder = Decoder(
            io_channels,
            latent_features + num_classes,
            io_features,
            list(reversed(channels)),
        )
        self.num_classes = num_classes
        self.latent_features = latent_features
        self.audio_loss_func = MultiResolutionSTFTLoss()
        self.learning_rate = learning_rate

    def sample(self, c: Tensor, eps: Optional[Tensor] = None) -> Tensor:
        """Decode a latent conditioned on class ``c``.

        Args:
            c: scalar integer label tensor (it is one-hot encoded here).
            eps: optional (1, latent_features) latent; when omitted, a draw
                from the N(0, I) prior is used — ``randn``, not uniform
                ``rand``, to match ``reparameterize`` and the KL term.
        """
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float().unsqueeze(0)
        if eps is None:
            eps = torch.randn((1, self.latent_features), device=self.device)
        z = torch.cat([eps, c], dim=1)
        return self.decoder(z)

    def loss_function(
        self, x: Tensor, x_hat: Tensor, mean: Tensor, logvar: Tensor
    ) -> Tensor:
        """Multi-resolution STFT reconstruction loss + Gaussian KL term."""
        audio_loss = self.audio_loss_func(x, x_hat)
        # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latents.
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return audio_loss + kld_loss

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """Sample z ~ N(mean, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x: Tensor, c: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Conditionally encode and decode ``x``; returns ``(x_hat, mean, logvar)``."""
        c = nn.functional.one_hot(c, num_classes=self.num_classes).float()
        # Class embedding becomes one extra channel alongside the audio.
        c_embedding = self.class_embedder(c).unsqueeze(1)
        x_embedding = self.data_embedder(x)
        x = torch.cat([x_embedding, c_embedding], dim=1)
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        # Condition the decoder on the raw one-hot label.
        z = torch.cat([z, c], dim=1)
        return self.decoder(z), mean, logvar

    def training_step(self, batch: Tensor, batch_idx: int, log: bool = True) -> Tensor:
        """One training step on a ``(signal, label)`` batch; returns the loss."""
        x, c = batch
        x_hat, mean, logvar = self.forward(x, c)
        loss = self.loss_function(x, x_hat, mean, logvar)
        if log:
            self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self) -> Optimizer:
        """AdamW over all parameters at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)